
batching

Batch text chunks together so that each batch fits within a maximum number of tokens.

The max tokens per batch is expected to be smaller than the model's context window size.

Functions:

| Name | Description |
| --- | --- |
| `batch_chunks_by_token_limit` | Batch text chunks together to fit within a max tokens per batch. |

batch_chunks_by_token_limit

batch_chunks_by_token_limit(chunks: Iterable[str], max_tokens_per_batch: TokenCount, joiner: str = '\n\n---\n\n') -> Iterator[str]

Batch text chunks together to fit within a max tokens per batch.

This function takes an iterable of text chunks and combines them into larger batches that will fit within the specified max tokens per batch. It uses token estimation to determine how many chunks can be combined safely.

Chunks are combined with a joiner (defined by the joiner parameter) to help distinguish between different chunks in the same batch.

This batching optimizes API usage by reducing the number of calls to the LLM provider, which can help avoid rate limiting errors for repositories with many commits or other text chunks.

The size of each chunk is overestimated to be conservative, that is, to ensure that the chunk will fit within the max tokens per batch. The size of each batch is estimated by summing the sizes of its chunks plus the joiner tokens, which is also an approximation.
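The module's `estimate_token_count` is not shown on this page, but a conservative estimator can be sketched with the common heuristic of roughly 4 characters per token; dividing by 3 instead of 4 and rounding up deliberately overestimates. This is a hypothetical stand-in, not the library's actual implementation:

```python
import math


def estimate_token_count_overestimate(text: str) -> int:
    """Deliberately overestimate the token count of a text.

    Assumes the common ~4-characters-per-token heuristic for English text and
    divides by 3 instead, rounding up, so batches built from these counts
    should still fit the real limit.
    """
    return math.ceil(len(text) / 3)
```

Any estimator works here as long as it errs on the high side; an underestimate could produce batches that exceed the model's real limit.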

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `chunks` | `Iterable[str]` | An iterable of strings, where each string is a chunk of text. | *required* |
| `max_tokens_per_batch` | `TokenCount` | The maximum number of tokens allowed per batch. | *required* |
| `joiner` | `str` | The string to use for joining chunks when batching them together. | `'\n\n---\n\n'` |

Yields:

| Type | Description |
| --- | --- |
| `str` | Batches of text chunks, each fitting within the max tokens per batch. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `max_tokens_per_batch` is not positive. |

Source code in src/brag/batching.py
def batch_chunks_by_token_limit(
    chunks: Iterable[str],
    max_tokens_per_batch: TokenCount,
    joiner: str = "\n\n---\n\n",
) -> Iterator[str]:
    """Batch text chunks together to fit within a max tokens per batch.

    This function takes an iterable of text chunks and combines them into larger
    batches that will fit within the specified max tokens per batch. It uses
    token estimation to determine how many chunks can be combined safely.

    Chunks are combined with a joiner (defined by the `joiner` parameter) to help
    distinguish between different chunks in the same batch.

    This batching optimizes API usage by reducing the number of calls to the LLM provider,
    which can help avoid rate limiting errors for repositories with many commits or
    other text chunks.

    The size of each chunk is overestimated to be conservative, that is, to ensure that the chunk
    will fit within the max tokens per batch. The size of each batch is estimated by summing the
    sizes of its chunks plus the joiner tokens, which is also an approximation.

    Args:
        chunks: An iterable of strings, where each string is a chunk of text.
        max_tokens_per_batch: The maximum number of tokens allowed per batch.
        joiner: The string to use for joining chunks when batching them together.

    Yields:
        Batches of text chunks, each fitting within the max tokens per batch.

    Raises:
        ValueError: If max_tokens_per_batch is not positive.
    """
    if max_tokens_per_batch <= 0:
        raise ValueError("max_tokens_per_batch must be positive")

    joiner_token_count = estimate_token_count(joiner, approximation_mode="overestimate")
    current_batch: str = ""
    current_batch_token_count = 0

    for chunk in chunks:
        # Use the overestimate strategy to be conservative, that is, to ensure that the chunk
        # will fit within the max tokens per batch.
        chunk_token_count = estimate_token_count(
            chunk, approximation_mode="overestimate"
        )

        # Only add joiner_token_count if current_batch is not empty
        additional_token_count = joiner_token_count if current_batch else 0

        # If adding this chunk would exceed the limit, yield the current batch and start a new one
        if current_batch and (
            current_batch_token_count + additional_token_count + chunk_token_count
            > max_tokens_per_batch
        ):
            yield current_batch
            current_batch = chunk
            current_batch_token_count = chunk_token_count
        else:
            current_batch = promptify(current_batch, chunk, joiner=joiner)
            current_batch_token_count += additional_token_count + chunk_token_count

    # Add the last batch if it's not empty
    if current_batch:
        yield current_batch
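The greedy accumulation above can be exercised end to end with a minimal self-contained sketch. Here `overestimate_tokens` is a hypothetical stand-in for `estimate_token_count(..., approximation_mode="overestimate")`, and plain string joining replaces the `promptify` helper; neither is the library's actual implementation:

```python
from collections.abc import Iterable, Iterator


def overestimate_tokens(text: str) -> int:
    # Hypothetical estimator: one token per 3 characters, rounded up.
    return -(-len(text) // 3)


def batch_chunks(
    chunks: Iterable[str],
    max_tokens_per_batch: int,
    joiner: str = "\n\n---\n\n",
) -> Iterator[str]:
    if max_tokens_per_batch <= 0:
        raise ValueError("max_tokens_per_batch must be positive")

    joiner_tokens = overestimate_tokens(joiner)
    batch, batch_tokens = "", 0

    for chunk in chunks:
        chunk_tokens = overestimate_tokens(chunk)
        # The joiner only costs tokens when the batch already has content.
        extra = joiner_tokens if batch else 0
        if batch and batch_tokens + extra + chunk_tokens > max_tokens_per_batch:
            # Adding this chunk would overflow: emit the batch, start fresh.
            yield batch
            batch, batch_tokens = chunk, chunk_tokens
        else:
            batch = f"{batch}{joiner}{chunk}" if batch else chunk
            batch_tokens += extra + chunk_tokens

    if batch:
        yield batch


# Each "aaa" chunk and the "|" joiner cost 1 token, so a 3-token budget
# fits two chunks plus one joiner per batch.
batches = list(batch_chunks(["aaa", "bbb", "ccc"], 3, joiner="|"))
# → ["aaa|bbb", "ccc"]
```

Note that a single chunk larger than `max_tokens_per_batch` is still yielded as its own batch; this function batches chunks, it does not split them.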