arxiv:2408.14906

Writing in the Margins: Better Inference Pattern for Long Context Retrieval

Published on Aug 27
· Submitted by kiranr on Aug 28
#1 Paper of the day

Abstract

In this paper, we introduce Writing in the Margins (WiM), a new inference pattern for Large Language Models designed to optimize the handling of long input sequences in retrieval-oriented tasks. This approach leverages the chunked prefill of the key-value cache to perform segment-wise inference, which enables efficient processing of extensive contexts along with the generation and classification of intermediate information ("margins") that guide the model towards specific tasks. This method increases computational overhead marginally while significantly enhancing the performance of off-the-shelf models without the need for fine-tuning. Specifically, we observe that WiM provides an average enhancement of 7.5% in accuracy for reasoning skills (HotpotQA, MultiHop-RAG) and more than a 30.0% increase in the F1-score for aggregation tasks (CWE). Additionally, we show how the proposed pattern fits into an interactive retrieval design that provides end-users with ongoing updates about the progress of context processing, and pinpoints the integration of relevant information into the final response. We release our implementation of WiM using Hugging Face Transformers library at https://github.com/writer/writing-in-the-margins.
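The inference pattern described in the abstract can be pictured as a loop over context segments. The following is a minimal, framework-free sketch that rests on several assumptions. `generate_margin`, `classify_margin`, and `answer` are hypothetical callables that stand in for LLM calls. The "cache" is a plain list of text segments, not a real key-value cache. Nothing here reproduces the API of the released writer/writing-in-the-margins code.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class WiMState:
    # Hypothetical stand-in for the KV cache: segments prefilled so far.
    cache: List[str] = field(default_factory=list)
    # Margins (intermediate notes) classified as relevant to the query.
    margins: List[str] = field(default_factory=list)


def chunk(text: str, size: int) -> List[str]:
    """Split the context into fixed-size segments (characters here, tokens in practice)."""
    return [text[i:i + size] for i in range(0, len(text), size)]


def writing_in_the_margins(
    context: str,
    query: str,
    chunk_size: int,
    generate_margin: Callable[[List[str], str], str],
    classify_margin: Callable[[str, str], bool],
    answer: Callable[[List[str], List[str], str], str],
) -> str:
    state = WiMState()
    for segment in chunk(context, chunk_size):
        # 1. Chunked prefill: only the new segment is added to the cache.
        state.cache.append(segment)
        # 2. Segment-wise inference: write a margin from the prefix so far and the query.
        margin = generate_margin(state.cache, query)
        # 3. Classification: keep only margins judged relevant to the query.
        if classify_margin(margin, query):
            state.margins.append(margin)
    # 4. Final answer conditioned on the full context plus the kept margins.
    return answer(state.cache, state.margins, query)


# Toy stand-ins for the model calls (hypothetical, for illustration only).
def toy_margin(cache: List[str], query: str) -> str:
    latest = cache[-1]
    return latest.strip() if query in latest else "NO"


def toy_classify(margin: str, query: str) -> bool:
    return margin != "NO"


def toy_answer(cache: List[str], margins: List[str], query: str) -> str:
    return " | ".join(margins)


print(writing_in_the_margins("the cat sat; a dog ran; cat naps", "cat", 8,
                             toy_margin, toy_classify, toy_answer))
```

In a real model, step 1 would run a forward pass over the new segment only and extend the KV cache with it. The memory savings discussed in the comments below come from that step.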

Community

Paper author Paper submitter

Congrats on the paper🔥Amazing work!

Amazing work, congratulations!

Nicely done! Great paper!

This is really cool!!

However, I've gotten myself very confused. Can someone explain why the following is true? I think I'm missing something simple.

"By splitting a prompt of length L into N chunks, each of size K, where N = L/K, the overall memory complexity of prefilling is reduced from O(L^2) to O(LK)."

In the chunked case, am I right to assume that the first chunk has a memory cost of K * K, since each of its K tokens attends to the other tokens in the same chunk?

For the second chunk, the cost is 2K * K, since the K tokens in this chunk now also attend to the K tokens of the previous chunk. I think this pattern continues for every subsequent chunk, so:

Total cost = cost(chunk 1) + ... + cost(chunk N)
           = K^2 + 2K^2 + ... + NK^2
           = K^2 (1 + 2 + ... + N)
           = K^2 N(N+1)/2
           = K^2 (L/K)(L/K + 1)/2    (substituting N = L/K)
           = K^2 (L^2/K^2 + L/K)/2
           = (L^2 + LK)/2

That result is still O(L^2), so I'm confused about why the paper says O(LK). I think there must be some fundamental flaw in my understanding.
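A quick numeric check shows that the algebra in the question is correct for the sum of the per-chunk costs. This sketch makes the same assumptions as the question: step i materializes a K * (iK) attention-score matrix, and K divides L. `cumulative_cost` and `closed_form` are illustrative names.

```python
def cumulative_cost(L: int, K: int) -> int:
    """Sum of score-matrix sizes K * (i * K) over all N = L / K prefill steps."""
    if L % K != 0:
        raise ValueError("this sketch assumes K divides L")
    N = L // K
    return sum(K * (i * K) for i in range(1, N + 1))


def closed_form(L: int, K: int) -> int:
    # (L^2 + L*K) / 2 from the derivation; L*(L+K) = K^2 * N * (N+1) is always even.
    return (L * L + L * K) // 2


for L, K in [(8, 2), (12, 3), (4096, 512)]:
    print(L, K, cumulative_cost(L, K), closed_form(L, K))
```

The two columns agree, so the derivation itself holds. The open question is whether summing over steps is the right measure of memory.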


Chunked prefill runs step by step, and each step is a single forward pass through the model. The worst-case memory therefore occurs during the last step, when the final chunk is prefilled. In that step, the K tokens of the chunk attend to all L tokens of the sequence, so the memory complexity is K * L.

Memory is allocated and then released within each forward pass, except for the KV cache and the model's parameters. That is why the costs of the individual chunks don't accumulate: what matters is the peak of a single step, O(LK), not the sum over all steps. The KV cache does persist across steps, but it grows only linearly with L, so it doesn't change that bound.
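Tracking the score-matrix size step by step makes the reply's point concrete. This minimal sketch assumes that each forward pass builds one (chunk tokens) x (tokens seen so far) score matrix and frees it afterwards, while the KV cache is kept. It counts entries of a single score matrix only and ignores heads, layers, and batch size. `prefill_steps` and `peak_and_total` are illustrative names.

```python
from typing import List, Tuple


def prefill_steps(L: int, K: int) -> List[Tuple[int, int]]:
    """(query tokens, key tokens) of the score matrix built at each chunked-prefill step."""
    shapes = []
    seen = 0
    while seen < L:
        q = min(K, L - seen)  # the final chunk may be shorter than K
        seen += q             # the KV cache now holds `seen` tokens
        shapes.append((q, seen))
    return shapes


def peak_and_total(L: int, K: int) -> Tuple[int, int]:
    sizes = [q * k for q, k in prefill_steps(L, K)]
    # Scores are freed after each forward pass, so peak memory is the largest
    # single step (at most K * L), not the sum over all steps.
    return max(sizes), sum(sizes)


L, K = 4096, 512
peak, total = peak_and_total(L, K)
print("unchunked:", L * L, "chunked peak:", peak, "sum over steps:", total)
```

With L = 4096 and K = 512, the peak is K * L = 2,097,152 score entries. Unchunked prefill needs L^2 = 16,777,216, so chunking gives an L/K = 8x reduction. The sum over steps, (L^2 + LK)/2 = 9,437,184, is larger than the peak, but it never has to be held in memory at once.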

A video summary is now available here - https://youtu.be/JODc9ku5djA

Paper author

I love it!

