Stop Wasting LLM Tokens

By Tobias Schnabel (Towards Data Science, August 2024)

Batching your inputs together can lead to substantial savings without compromising on performance


If you use LLMs to annotate or process larger datasets, chances are you don't even realize that you're wasting a lot of input tokens. As you repeatedly call an LLM to process text snippets or entire documents, your task instructions and static few-shot examples are repeated for every input example. Just like neatly stacking dishes saves space, batching inputs together can result in substantial savings.

Assume you want to tag a small corpus of 1,000 single-page documents, with instructions and few-shot examples that are about half a page long. Annotating each document separately would cost you about 1M input tokens. However, if you annotated ten documents in the same call, you'd save about 300K input tokens (or 30%) because the instructions don't have to be repeated! As we'll see in the example below, this often comes with minimal performance loss (or even a performance gain), especially when you optimize your prompt alongside it.

Below, I have plotted the savings, assuming that our average document length is D tokens and our shared instructions and few-shot examples are r*D tokens long. The example scenario from the previous paragraph, where the instructions are half the length of a document (r = 0.5), appears in blue below. For longer shared instructions, the savings can be even higher:

The main takeaways are:

  • Even with relatively short instructions (blue line), minibatching adds value.
  • You don't need very large minibatches; most of the savings can be obtained with moderate minibatch sizes (B ≤ 10).
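The curves above follow from a simple calculation: annotating a document separately costs D*(1 + r) input tokens, whereas a minibatch of B documents costs B*D + r*D tokens, because the shared part is sent only once. Here is a minimal sketch in plain Python (the helper name is mine) that reproduces the 30% figure from the example above:

def relative_savings(r: float, B: int) -> float:
    """Fraction of input tokens saved when B documents share one copy of
    instructions that are r times as long as the average document."""
    cost_separate = B * (1 + r)  # B calls, instructions repeated every time
    cost_batched = B + r         # one call, instructions sent once
    return (cost_separate - cost_batched) / cost_separate

# Example from above: instructions half a document long, minibatches of 10
print(relative_savings(r=0.5, B=10))  # -> 0.3, i.e., roughly 30% saved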

Let's get practical with a task where we want to categorize pieces of text for further analysis. We'll use a fun task from the Natural-Instructions benchmark in which we need to annotate sentences in debates with one of four categories (value, fact, testimony, or policy).

Looking at an example, we see that we get the current topic for context and then need to categorize the sentence in question.

{
  "input": {
    "topic": "the fight for justice,equality,peaceand love is futile",
    "sentence": "What matters is what I am personally doing to ensure that I am filling the cup!"
  },
  "output": "Value"
}
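For intuition, here is what a minibatched request could look like if you rolled it by hand. This is purely illustrative (the instruction string, the id/label fields, and the output convention are my own), not the exact format SAMMO produces later on:

import json

instructions = "Classify each sentence as Value, Fact, Testimony, or Policy."

# Several inputs packed into one call; the shared instructions are sent only once.
minibatch = [
    {"id": 0, "topic": "the fight for justice,equality,peaceand love is futile",
     "sentence": "What matters is what I am personally doing to ensure that I am filling the cup!"},
    {"id": 1, "topic": "...", "sentence": "..."},
]

prompt = (
    instructions
    + "\n\nInput:\n"
    + json.dumps(minibatch, indent=2)
    + '\n\nReturn a JSON list with one {"id", "label"} object per input.'
)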

One question we haven’t answered yet:

How do we pick the right minibatch size?

Previous work has shown that the best minibatch size depends on the task as well as the model. We essentially have two options:

  1. We pick a reasonable minibatch size, say 5, and hope that we don't see any performance drops.
  2. We optimize the minibatch size along with other choices, e.g., the number of few-shot examples.

As you might have guessed, we’ll pursue option 2 here. To run our experiments, we’ll use SAMMO, a framework for LLM calling and prompt optimization.
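If you want to follow along, SAMMO is available on PyPI (pip install sammo). A minimal setup might look like the sketch below; the runner class and its arguments reflect my reading of SAMMO's README and may differ in your version, so double-check the SAMMO docs:

# pip install sammo
from sammo.runners import OpenAIChat  # assumed module path -- check the SAMMO docs

runner = OpenAIChat(
    model_id="gpt-3.5-turbo",             # any chat model you have access to
    api_config={"api_key": "<YOUR_KEY>"},
)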

Prompts are coded up in SAMMO as prompt programs (which are simply nested Python classes that get called with input data). We'll structure our task into three sections and format our minibatches as JSON.

# Imports assume a recent SAMMO release; exact module paths may differ.
from sammo.components import Output
from sammo.dataformatters import JSONDataFormatter
from sammo.instructions import MetaPrompt, Section, FewshotExamples, InputData

def prompt_program(fewshot_data, n_fewshot_examples=5, minibatch_size=1):
    return Output(
        MetaPrompt(
            [
                # Shared task instructions (the Natural-Instructions task definition)
                Section("Instructions", task["Definition"]),
                # Static few-shot examples
                Section(
                    "Examples",
                    FewshotExamples(fewshot_data, n_fewshot_examples),
                ),
                # The (minibatched) inputs to annotate
                Section("Output in same format as above", InputData()),
            ],
            data_formatter=JSONDataFormatter(),
            render_as="markdown",
        ).with_extractor(on_error="empty_result"),
        minibatch_size=minibatch_size,
        on_error="empty_result",
    )

Running this without minibatching and using five few-shot examples, we get an accuracy of 0.76 and have to pay 58255 input tokens.

Let’s now explore how minibatching affects costs and performance. Since minibatching reduces the total input costs, we can now use some of those savings to add more few-shot examples! We can study those trade-offs by setting up a search space in SAMMO:

from sammo import search_op  # assumed import; adjust to your SAMMO version

def search_space(fewshot_data):
    minibatch_size = search_op.one_of([1, 5, 10], name="minibatch_size")
    n_fewshot_examples = search_op.one_of([5, 20], name="n_fewshot")

    return prompt_program(fewshot_data, n_fewshot_examples, minibatch_size)
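To actually enumerate these six configurations, you hand the search space to one of SAMMO's search routines. The sketch below follows my recollection of SAMMO's examples (EnumerativeSearch with fit_transform and show_report); treat the exact names and signatures as assumptions and consult the SAMMO documentation. It also assumes an accuracy metric and the labeled data (d_train, fewshot_data) were set up when loading the task:

import functools
from sammo.search import EnumerativeSearch  # assumed module path

# accuracy: user-defined metric callable; d_train, fewshot_data: labeled DataTables
searcher = EnumerativeSearch(
    runner,
    functools.partial(search_space, fewshot_data),  # bind the few-shot pool
    accuracy,
)
best_prompt = searcher.fit_transform(d_train)
searcher.show_report()  # prints a trade-off table like the one below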

Running this shows us the full gamut of trade-offs:

  setting                                    objective  costs                               parse_errors
  -----------------------------------------  ---------  ----------------------------------  ------------
* {'minibatch_size': 1, 'n_fewshot': 5}      0.76       {'input': 58255, 'output': 5817}    0.0
  {'minibatch_size': 1, 'n_fewshot': 20}     0.76       {'input': 133355, 'output': 6234}   0.0
  {'minibatch_size': 5, 'n_fewshot': 5}      0.75       {'input': 15297, 'output': 5695}    0.0
  {'minibatch_size': 5, 'n_fewshot': 20}     0.77       {'input': 30317, 'output': 5524}    0.0
  {'minibatch_size': 10, 'n_fewshot': 5}     0.73       {'input': 9928, 'output': 5633}     0.0
* {'minibatch_size': 10, 'n_fewshot': 20}    0.77       {'input': 17438, 'output': 5432}    0.0

So, even with 20 few-shot examples, we save nearly 70% of the input costs ((58255 - 17438) / 58255 ≈ 0.70), all while maintaining overall accuracy! As an exercise, you can implement your own objective that automatically factors in costs, or include different ways of formatting the data in the search space.
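If you want to try that exercise, one simple option is to penalize accuracy by the (normalized) number of input tokens a candidate uses. The sketch below is generic Python; the budget and weight are illustrative knobs I made up, not values from this post:

def cost_aware_objective(accuracy: float, input_tokens: int,
                         budget: int = 60_000, weight: float = 0.1) -> float:
    """Score a candidate by accuracy minus a penalty for input-token spend.

    A candidate that uses the full (illustrative) budget loses `weight` points.
    """
    return accuracy - weight * (input_tokens / budget)

# Plugging in two rows from the table above:
# minibatch_size=1,  n_fewshot=5:  0.76 - 0.1 * 58255 / 60000 ≈ 0.66
# minibatch_size=10, n_fewshot=20: 0.77 - 0.1 * 17438 / 60000 ≈ 0.74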

Implicit in all of this is that (i) we have enough input examples that share the same instructions and (ii) we have some flexibility regarding latency. The first assumption is met in many annotation scenarios but obviously doesn't hold for one-off queries. In annotation or other offline processing tasks, latency is also not critical, since throughput matters most. However, if your task is to give a user an answer as quickly as possible, it might make more sense to issue B parallel calls than one call with B input examples.
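For that latency-sensitive case, the parallel variant is a few lines of asyncio; annotate_one below is a hypothetical stand-in for whatever client call you actually use:

import asyncio

async def annotate_one(sentence: str) -> str:
    """Hypothetical single-example call to your LLM client of choice."""
    raise NotImplementedError

async def annotate_in_parallel(sentences: list[str]) -> list[str]:
    # B parallel calls: lowest latency, but the instructions are paid for B times.
    return await asyncio.gather(*(annotate_one(s) for s in sentences))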
