llm

AsyncVirtualLM

Bases: AsyncLM

Source code in genlm/backend/llm/vllm.py
class AsyncVirtualLM(AsyncLM):
    default_params = {
        "max_tokens": 1,
        "n": 1,
        "detokenize": False,
        "stop": None,
        "ignore_eos": True,
    }

    def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
        """Initialize an `AsyncVirtualLM` instance.

        Args:
            async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
            cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
            cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to {}.

        Note:
            The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
        """
        self.async_llm_engine = async_llm_engine
        self.tokenizer = async_llm_engine.engine.get_tokenizer()
        self.request_counter = Counter()
        self.cache = (
            OutputCache(maxsize=cache_size, **cache_opts)
            if cache_size > 0
            else None
        )

        async_llm_engine.engine.log_stats = False

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, engine_opts=None, **kwargs):
        """Create a `AsyncVirtualLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load.
            engine_opts (dict): Additional options to pass to the `AsyncLLMEngine`. The engine will be
                configured with prefix caching enabled and async output processing disabled by default.
            **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

        Returns:
            (AsyncVirtualLM): An `AsyncVirtualLM` instance.
        """
        if not HAS_VLLM:
            raise ImportError(  # pragma: no cover
                "vLLM not available. Install vLLM or use AsyncTransformer instead."
            )

        if engine_opts is not None and "enable_chunked_prefill" in engine_opts:
            if engine_opts["enable_chunked_prefill"]:
                warnings.warn(  # pragma: no cover
                    "Setting enable_chunked_prefill to True may interfere with AsyncVirtualLM's "
                    "custom sampling functionality."
                )

        engine_opts = {
            "enable_prefix_caching": True,
            "disable_log_requests": True,
            "disable_async_output_proc": True,  # This parameter forces vLLM to use v0, which is currently what we want to do.
            **(engine_opts or {}),
        }

        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
        )

        return cls(engine, **kwargs)

    @property
    def underlying_model(self):
        return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously with output caching.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            result (torch.Tensor): Normalized log probability tensor.

        Warning:
            Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
            For synchronous usage, use the `next_token_logprobs_sync()` method instead.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        result = await self._next_token_logprobs(key)

        if self.cache is not None:
            self.cache[key] = result

        return result

    async def _next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        req_id = str(next(self.request_counter))
        prompt = TokensPrompt(prompt_token_ids=token_ids)

        outputs = []
        processor = PassThroughLogitsProcessor()
        async for output in self.async_llm_engine.generate(
            prompt=prompt,
            sampling_params=SamplingParams(
                **self.default_params, logits_processors=[processor]
            ),
            request_id=req_id,
        ):
            if output.finished:
                outputs.append(output)

        assert processor.log_probs is not None, (
            "Log probs should be set by the logits processor."
        )
        return processor.log_probs

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self.batch_next_token_logprobs_sync([token_ids])[0]

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """
        Request log probabilities of next tokens in a batch synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

        Returns:
            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
        """
        req_ids = []
        req_id2processors = {}
        for token_ids in token_ids_list:
            req_id = str(next(self.request_counter))
            req_ids.append(req_id)
            processor = PassThroughLogitsProcessor()
            req_id2processors[req_id] = processor
            self.async_llm_engine.engine.add_request(
                prompt=TokensPrompt(prompt_token_ids=token_ids),
                params=SamplingParams(
                    **self.default_params, logits_processors=[processor]
                ),
                request_id=req_id,
            )

        while self.async_llm_engine.engine.has_unfinished_requests():
            output = self.async_llm_engine.engine.step()
            for out in output:
                if out.finished:
                    assert out.request_id in req_id2processors, (
                        f"{out.request_id} not in requested IDs"
                    )

        return torch.stack(
            [req_id2processors[req_id].log_probs for req_id in req_ids]
        )

    def clear_cache(self):
        """Clear output cache."""
        if self.cache:
            self.cache.clear()

    def __del__(self):
        """Clean up resources on deletion."""
        self._cleanup_engine()

    def _cleanup_engine(self):
        """Clean up the vLLM engine and associated resources."""
        if async_engine := getattr(self, "async_llm_engine", None):
            async_engine.shutdown_background_loop()
            destroy_model_parallel()
            destroy_distributed_environment()

    async def sample(
        self,
        prompt_token_ids,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample from the language model.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
            temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
            max_tokens (int): The maximum number of tokens to generate.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[int]): The sampled token IDs.
        """
        async for output in self.async_llm_engine.generate(
            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
            sampling_params=SamplingParams(
                n=1,
                max_tokens=max_tokens,
                temperature=temperature,
                seed=seed,
                stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
            ),
            request_id=str(next(self.request_counter)),
        ):
            if output.finished:
                assert len(output.outputs) == 1, (
                    "Expected exactly one sequence group"
                )
                token_ids = list(output.outputs[0].token_ids)
                if token_ids[-1] in eos_token_ids:
                    token_ids = token_ids[:-1]
                return token_ids

__init__(async_llm_engine, cache_size=0, cache_opts={})

Initialize an AsyncVirtualLM instance.

Parameters:

- async_llm_engine (AsyncLLMEngine): The async vLLM engine instance. Required.
- cache_size (int): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
- cache_opts (dict): Additional options to pass to the OutputCache constructor. Defaults to {}.

Note

The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.

Source code in genlm/backend/llm/vllm.py
def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
    """Initialize an `AsyncVirtualLM` instance.

    Args:
        async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
        cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
        cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to {}.

    Note:
        The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
    """
    self.async_llm_engine = async_llm_engine
    self.tokenizer = async_llm_engine.engine.get_tokenizer()
    self.request_counter = Counter()
    self.cache = (
        OutputCache(maxsize=cache_size, **cache_opts)
        if cache_size > 0
        else None
    )

    async_llm_engine.engine.log_stats = False

    super().__init__(tokenizer=self.tokenizer)
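
A minimal construction sketch (not taken from the library's docs): it builds the vLLM engine directly and then wraps it, mirroring what from_name below does. The model name "gpt2" and the engine flags are illustrative assumptions and depend on the installed vLLM (v0) version.

from vllm import AsyncEngineArgs, AsyncLLMEngine
from genlm.backend.llm.vllm import AsyncVirtualLM

# Build a vLLM v0 async engine with the same options from_name applies by default.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(
        model="gpt2",                     # example model name
        enable_prefix_caching=True,
        disable_async_output_proc=True,   # keeps vLLM on the v0 engine path
    )
)

# Wrap it with an output cache holding up to 256 prompts.
llm = AsyncVirtualLM(engine, cache_size=256)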

from_name(model_name, engine_opts=None, **kwargs) classmethod

Create an AsyncVirtualLM instance from a model name.

Parameters:

- model_name (str): Name of the model to load. Required.
- engine_opts (dict): Additional options to pass to the AsyncLLMEngine. The engine will be configured with prefix caching enabled and async output processing disabled by default. Defaults to None.
- **kwargs: Additional arguments passed to the AsyncVirtualLM constructor.

Returns:

(AsyncVirtualLM): An AsyncVirtualLM instance.

Source code in genlm/backend/llm/vllm.py
@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
    """Create a `AsyncVirtualLM` instance from a model name.

    Args:
        model_name (str): Name of the model to load.
        engine_opts (dict): Additional options to pass to the `AsyncLLMEngine`. The engine will be
            configured with prefix caching enabled and async output processing disabled by default.
        **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

    Returns:
        (AsyncVirtualLM): An `AsyncVirtualLM` instance.
    """
    if not HAS_VLLM:
        raise ImportError(  # pragma: no cover
            "vLLM not available. Install vLLM or use AsyncTransformer instead."
        )

    if engine_opts is not None and "enable_chunked_prefill" in engine_opts:
        if engine_opts["enable_chunked_prefill"]:
            warnings.warn(  # pragma: no cover
                "Setting enable_chunked_prefill to True may interfere with AsyncVirtualLM's "
                "custom sampling functionality."
            )

    engine_opts = {
        "enable_prefix_caching": True,
        "disable_log_requests": True,
        "disable_async_output_proc": True,  # This parameter forces vLLM to use v0, which is currently what we want to do.
        **(engine_opts or {}),
    }

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
    )

    return cls(engine, **kwargs)
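
A usage sketch for from_name (assumptions: the import path follows the source location above, "gpt2" is an example model name, and a GPU-capable vLLM install is available). The synchronous methods are used here because the warning under next_token_logprobs advises against wrapping the async method in asyncio.run.

from genlm.backend.llm.vllm import AsyncVirtualLM

llm = AsyncVirtualLM.from_name("gpt2", cache_size=128)

# Score the next token after a prompt without touching the async API.
prompt = llm.tokenizer.encode("The capital of France is")
logprobs = llm.next_token_logprobs_sync(prompt)   # torch.Tensor over the vocabulary
best = logprobs.argmax().item()
print(llm.tokenizer.decode([best]))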

next_token_logprobs(token_ids) async

Request log probabilities of next token asynchronously with output caching.

Parameters:

- token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

result (torch.Tensor): Normalized log probability tensor.

Warning

Do not use asyncio.run(next_token_logprobs()) as it may interfere with vLLM's background loop. For synchronous usage, use the next_token_logprobs_sync() method instead.

Source code in genlm/backend/llm/vllm.py
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token asynchronously with output caching.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        result (torch.Tensor): Normalized log probability tensor.

    Warning:
        Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
        For synchronous usage, use the `next_token_logprobs_sync()` method instead.
    """
    key = tuple(token_ids)

    if self.cache is not None and key in self.cache:
        return self.cache[key]

    result = await self._next_token_logprobs(key)

    if self.cache is not None:
        self.cache[key] = result

    return result
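
Because results are cached by token sequence, awaiting the same prompt twice only queries the engine once. A sketch, meant to be awaited inside an already-running event loop (per the warning above) with an llm built as in the earlier examples:

import torch

async def score_twice(llm, token_ids):
    first = await llm.next_token_logprobs(token_ids)    # computed by the engine and cached
    second = await llm.next_token_logprobs(token_ids)   # served from the output cache
    assert torch.equal(first, second)
    return first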

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Parameters:

- token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

(torch.Tensor): Normalized log probability tensor.

Source code in genlm/backend/llm/vllm.py
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    return self.batch_next_token_logprobs_sync([token_ids])[0]

batch_next_token_logprobs_sync(token_ids_list)

Request log probabilities of next tokens in a batch synchronously.

Parameters:

- token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model. Required.

Returns:

(torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.

Source code in genlm/backend/llm/vllm.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """
    Request log probabilities of next tokens in a batch synchronously.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

    Returns:
        (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
    """
    req_ids = []
    req_id2processors = {}
    for token_ids in token_ids_list:
        req_id = str(next(self.request_counter))
        req_ids.append(req_id)
        processor = PassThroughLogitsProcessor()
        req_id2processors[req_id] = processor
        self.async_llm_engine.engine.add_request(
            prompt=TokensPrompt(prompt_token_ids=token_ids),
            params=SamplingParams(
                **self.default_params, logits_processors=[processor]
            ),
            request_id=req_id,
        )

    while self.async_llm_engine.engine.has_unfinished_requests():
        output = self.async_llm_engine.engine.step()
        for out in output:
            if out.finished:
                assert out.request_id in req_id2processors, (
                    f"{out.request_id} not in requested IDs"
                )

    return torch.stack(
        [req_id2processors[req_id].log_probs for req_id in req_ids]
    )
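
A batched synchronous sketch, reusing the llm from the from_name example above; the prompt strings are illustrative.

prompts = ["Paris is the capital of", "Berlin is the capital of"]
batch = [llm.tokenizer.encode(p) for p in prompts]

# One row of next-token log probabilities per prompt: shape (len(batch), vocab_size).
logprob_matrix = llm.batch_next_token_logprobs_sync(batch)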

clear_cache()

Clear output cache.

Source code in genlm/backend/llm/vllm.py
def clear_cache(self):
    """Clear output cache."""
    if self.cache:
        self.cache.clear()

__del__()

Clean up resources on deletion.

Source code in genlm/backend/llm/vllm.py
def __del__(self):
    """Clean up resources on deletion."""
    self._cleanup_engine()

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Parameters:

- prompt_token_ids (list[int]): The token IDs of the prompt. Required.
- max_tokens (int): The maximum number of tokens to generate. Required.
- eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens. Required.
- temperature (float): The temperature to use to rescale the logits. Defaults to 1.0.
- seed (int): The seed for the random number generator. Defaults to None.

Returns:

(list[int]): The sampled token IDs.

Source code in genlm/backend/llm/vllm.py
async def sample(
    self,
    prompt_token_ids,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Sample from the language model.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
        max_tokens (int): The maximum number of tokens to generate.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs.
    """
    async for output in self.async_llm_engine.generate(
        prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
        sampling_params=SamplingParams(
            n=1,
            max_tokens=max_tokens,
            temperature=temperature,
            seed=seed,
            stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
        ),
        request_id=str(next(self.request_counter)),
    ):
        if output.finished:
            assert len(output.outputs) == 1, (
                "Expected exactly one sequence group"
            )
            token_ids = list(output.outputs[0].token_ids)
            if token_ids[-1] in eos_token_ids:
                token_ids = token_ids[:-1]
            return token_ids
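
A sampling sketch, to be awaited inside a running event loop; the prompt text, temperature, and use of the tokenizer's EOS id are illustrative choices.

async def continue_text(llm, text, n=20):
    prompt_ids = llm.tokenizer.encode(text)
    eos_ids = [llm.tokenizer.eos_token_id]
    token_ids = await llm.sample(
        prompt_token_ids=prompt_ids,
        max_tokens=n,
        eos_token_ids=eos_ids,
        temperature=0.8,
        seed=0,            # fixed seed for a reproducible continuation
    )
    return llm.tokenizer.decode(token_ids)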

AsyncTransformer

Bases: AsyncLM

Asynchronous wrapper around a HuggingFace causal language model with caching support.

This class provides an asynchronous interface to HuggingFace language models with automatic batching and caching (output and KV) for improved efficiency.

Source code in genlm/backend/llm/hf.py
class AsyncTransformer(AsyncLM):
    """Asynchronous wrapper around a HuggingFace causal language model with caching support.

    This class provides an asynchronous interface to HuggingFace language models with automatic batching
    and caching (output and KV) for improved efficiency.
    """

    @classmethod
    def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
        """Create an AsyncTransformer instance from a pretrained HuggingFace model.

        Args:
            model_id (str): Model identifier in HuggingFace's model hub.
            bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
                Defaults to None.
            hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
                Defaults to None.
            **kwargs: Additional arguments passed to the `AsyncTransformer` constructor.

        Returns:
            (AsyncTransformer): An initialized `AsyncTransformer` instance.
        """
        if bitsandbytes_opts:
            bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
        else:
            bnb_config = None

        _hf_opts = {
            "device_map": "auto",
            "torch_dtype": "auto",
        }
        if hf_opts:
            _hf_opts.update(hf_opts)

        tok = AutoTokenizer.from_pretrained(model_id)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, **_hf_opts
        )

        return cls(mod, tok, **kwargs)

    @torch.no_grad()
    def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
        """Initialize an AsyncTransformer instance.

        Args:
            hf_model: A HuggingFace CausalLM model instance.
            hf_tokenizer: A HuggingFace Tokenizer.
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
                Defaults to 20.
            timeout (float, optional): Seconds to wait since last query before processing current batch.
                Defaults to 0.02.
        """
        self.model = hf_model
        self.tokenizer = hf_tokenizer
        self.device = hf_model.device
        self.cache = TokenTrie()

        # Queries to be batched. Each query is a sequence of tokens,
        # and a Future to be called when the query is resolved.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        self.model.eval()

        super().__init__(tokenizer=self.tokenizer)

    def clear_cache(self):
        """Clear the cache of log probabilities and key/value pairs."""
        self.cache = TokenTrie()

    def clear_kv_cache(self):
        """Clear any key and value vectors from the cache."""
        self.cache.clear_kv_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
        to completion."""
        self.queries = []

    @torch.no_grad()
    def cache_kv(self, prompt_tokens):
        """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

        Args:
            prompt_tokens (list[int]): token ids for the prompt to cache.
        """
        result = self.model(torch.tensor([prompt_tokens]).to(self.device))
        node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
        node.past_key_values = result.past_key_values

    @torch.no_grad()
    def batch_evaluate_queries(self):
        """
        Process a batch of queued language model queries.

        This method is called internally when the `batch_size` has been met or the `timeout` has expired.
        """

        queries, self.queries = self.queries, []
        if len(queries) == 0:
            return

        query_groups = defaultdict(list)
        for query in queries:
            key = tuple(query.prompt)  # XXX: cache based on past_len too?
            query_groups[key].append(query)

        # Use one representative query from each group
        unique_queries = [group[0] for group in query_groups.values()]

        past_example = next((q.past for q in unique_queries if q.past), False)
        max_past_length = max(q.past_len for q in unique_queries)
        max_query_length = max(len(q.prompt) for q in unique_queries)

        padding_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else 0
        )

        input_ids = torch.tensor(
            [
                q.prompt_padded(padding_token_id, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        attn_masks = torch.tensor(
            [
                q.attention_mask(max_past_length, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        posn_ids = torch.tensor(
            [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
        ).to(self.device)
        if past_example:
            pasts = [
                [
                    torch.cat(
                        (
                            *(
                                q.past_padded(
                                    layer,
                                    j,
                                    max_past_length,
                                    past_example[0][0].dtype,
                                    self.device,
                                    past_example[0][0].shape,
                                )
                                for q in unique_queries
                            ),
                        ),
                        dim=0,
                    )
                    for j in range(2)
                ]
                for layer in range(len(past_example))
            ]
        else:
            pasts = None

        pasts = DynamicCache.from_legacy_cache(pasts)

        results = self.model(
            input_ids,
            attention_mask=attn_masks,
            position_ids=posn_ids,
            past_key_values=pasts,
            use_cache=pasts is not None,
        )

        assert len(results.logits) == len(unique_queries)

        for i, q in enumerate(unique_queries):
            result = results.logits[i]
            for dup_query in query_groups[tuple(q.prompt)]:
                dup_query.future.set_result(result)

    @torch.no_grad()
    def add_query(self, query, future, past):
        """Add a query to be evaluated in the next batch.

        This method is called internally when a `next_token_logprobs` request is made.

        Args:
            query (list[int]): Token IDs representing the query prompt
            future (asyncio.Future): Future to store the result in
            past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
                or None if this is a new query
        """
        self.queries.append(Query(query, future, past))

        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple:
                - CacheNode: The deepest node in the cache tree that matches the token sequence
                - int: Number of tokens matched from the start of token_ids
                - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                    or None if no cached states were found
                - int: Base index indicating where the past states start in token_ids
        """
        # Walk while tokens can be found
        node = self.cache
        next_token_index = 0

        past = None
        base = 0
        while next_token_index < len(token_ids):
            if node.past_key_values is not None:
                past = node.past_key_values
                base = next_token_index
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        return node, next_token_index, past, base

    @torch.no_grad()
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, base = self.walk_cache(token_ids)

        # If we processed all tokens, then we're done.
        if next_token_index == len(token_ids):
            return node.logprobs

        # Create a future with the prompt
        future = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], future, past)
        logits = await future

        # Create new nodes
        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    @torch.no_grad()
    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        # Walk while tokens can be found
        node, next_token_index, past, base = self.walk_cache(token_ids)

        if next_token_index == len(token_ids):
            return node.logprobs

        logits = self.model(
            torch.tensor([token_ids[base:]]).to(self.device),
            past_key_values=node.past_key_values,
            use_cache=node.past_key_values is not None,
        ).logits[0]

        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    def next_token_logprobs_uncached(self, token_ids):
        """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        with torch.no_grad():
            logits = self.model(
                torch.tensor([token_ids]).to(self.device),
                past_key_values=None,
                use_cache=False,
            ).logits[0]
            return torch.log_softmax(logits[-1], dim=0)

from_name(model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs) classmethod

Create an AsyncTransformer instance from a pretrained HuggingFace model.

Parameters:

- model_id (str): Model identifier in HuggingFace's model hub. Required.
- bitsandbytes_opts (dict): Additional configuration options for bitsandbytes quantization. Defaults to None.
- hf_opts (dict): Additional configuration options for loading the HuggingFace model. Defaults to None.
- **kwargs: Additional arguments passed to the AsyncTransformer constructor.

Returns:

(AsyncTransformer): An initialized AsyncTransformer instance.

Source code in genlm/backend/llm/hf.py
@classmethod
def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
    """Create an AsyncTransformer instance from a pretrained HuggingFace model.

    Args:
        model_id (str): Model identifier in HuggingFace's model hub.
        bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
            Defaults to None.
        hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
            Defaults to None.
        **kwargs: Additional arguments passed to the `AsyncTransformer` constructor.

    Returns:
        (AsyncTransformer): An initialized `AsyncTransformer` instance.
    """
    if bitsandbytes_opts:
        bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
    else:
        bnb_config = None

    _hf_opts = {
        "device_map": "auto",
        "torch_dtype": "auto",
    }
    if hf_opts:
        _hf_opts.update(hf_opts)

    tok = AutoTokenizer.from_pretrained(model_id)
    mod = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, **_hf_opts
    )

    return cls(mod, tok, **kwargs)
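
A construction sketch for the HuggingFace-backed variant; the model id, device placement, and batching knobs are illustrative assumptions, and the import path follows the source location above.

from genlm.backend.llm.hf import AsyncTransformer

llm = AsyncTransformer.from_name(
    "gpt2",                          # example model id
    hf_opts={"device_map": "cpu"},   # override the default "auto" placement if desired
    batch_size=32,                   # flush the auto-batching queue at 32 queued queries
    timeout=0.05,                    # or 50 ms after the last query arrives
)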

__init__(hf_model, hf_tokenizer, batch_size=20, timeout=0.02)

Initialize an AsyncTransformer instance.

Parameters:

- hf_model: A HuggingFace CausalLM model instance. Required.
- hf_tokenizer: A HuggingFace Tokenizer. Required.
- batch_size (int): Maximum queries to process in one batch during auto-batching. Defaults to 20.
- timeout (float): Seconds to wait since the last query before processing the current batch. Defaults to 0.02.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
    """Initialize an AsyncTransformer instance.

    Args:
        hf_model: A HuggingFace CausalLM model instance.
        hf_tokenizer: A HuggingFace Tokenizer.
        batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
            Defaults to 20.
        timeout (float, optional): Seconds to wait since last query before processing current batch.
            Defaults to 0.02.
    """
    self.model = hf_model
    self.tokenizer = hf_tokenizer
    self.device = hf_model.device
    self.cache = TokenTrie()

    # Queries to be batched. Each query is a sequence of tokens,
    # and a Future to be called when the query is resolved.
    self.queries = []
    self.batch_size = batch_size
    self.timeout = timeout
    self.timer = None

    self.model.eval()

    super().__init__(tokenizer=self.tokenizer)

clear_cache()

Clear the cache of log probabilities and key/value pairs.

Source code in genlm/backend/llm/hf.py
def clear_cache(self):
    """Clear the cache of log probabilities and key/value pairs."""
    self.cache = TokenTrie()

clear_kv_cache()

Clear any key and value vectors from the cache.

Source code in genlm/backend/llm/hf.py
def clear_kv_cache(self):
    """Clear any key and value vectors from the cache."""
    self.cache.clear_kv_cache()

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/hf.py
def reset_async_queries(self):
    """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
    to completion."""
    self.queries = []

cache_kv(prompt_tokens)

Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

Parameters:

- prompt_tokens (list[int]): Token IDs for the prompt to cache. Required.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def cache_kv(self, prompt_tokens):
    """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

    Args:
        prompt_tokens (list[int]): token ids for the prompt to cache.
    """
    result = self.model(torch.tensor([prompt_tokens]).to(self.device))
    node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
    node.past_key_values = result.past_key_values
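
A sketch of reusing a cached prefix with the llm from the from_name sketch above: after cache_kv, any query that starts with those tokens only runs the model on the new suffix. The prompt strings are illustrative.

prefix = llm.tokenizer.encode("You are a helpful assistant.")
llm.cache_kv(prefix)   # store logits and key/value states for the prefix

query = prefix + llm.tokenizer.encode(" What is the capital of France?")
logprobs = llm.next_token_logprobs_sync(query)   # only the suffix is evaluated by the model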

batch_evaluate_queries()

Process a batch of queued language model queries.

This method is called internally when the batch_size has been met or the timeout has expired.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def batch_evaluate_queries(self):
    """
    Process a batch of queued language model queries.

    This method is called internally when the `batch_size` has been met or the `timeout` has expired.
    """

    queries, self.queries = self.queries, []
    if len(queries) == 0:
        return

    query_groups = defaultdict(list)
    for query in queries:
        key = tuple(query.prompt)  # XXX: cache based on past_len too?
        query_groups[key].append(query)

    # Use one representative query from each group
    unique_queries = [group[0] for group in query_groups.values()]

    past_example = next((q.past for q in unique_queries if q.past), False)
    max_past_length = max(q.past_len for q in unique_queries)
    max_query_length = max(len(q.prompt) for q in unique_queries)

    padding_token_id = (
        self.tokenizer.pad_token_id
        if self.tokenizer.pad_token_id is not None
        else 0
    )

    input_ids = torch.tensor(
        [
            q.prompt_padded(padding_token_id, max_query_length)
            for q in unique_queries
        ]
    ).to(self.device)
    attn_masks = torch.tensor(
        [
            q.attention_mask(max_past_length, max_query_length)
            for q in unique_queries
        ]
    ).to(self.device)
    posn_ids = torch.tensor(
        [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
    ).to(self.device)
    if past_example:
        pasts = [
            [
                torch.cat(
                    (
                        *(
                            q.past_padded(
                                layer,
                                j,
                                max_past_length,
                                past_example[0][0].dtype,
                                self.device,
                                past_example[0][0].shape,
                            )
                            for q in unique_queries
                        ),
                    ),
                    dim=0,
                )
                for j in range(2)
            ]
            for layer in range(len(past_example))
        ]
    else:
        pasts = None

    pasts = DynamicCache.from_legacy_cache(pasts)

    results = self.model(
        input_ids,
        attention_mask=attn_masks,
        position_ids=posn_ids,
        past_key_values=pasts,
        use_cache=pasts is not None,
    )

    assert len(results.logits) == len(unique_queries)

    for i, q in enumerate(unique_queries):
        result = results.logits[i]
        for dup_query in query_groups[tuple(q.prompt)]:
            dup_query.future.set_result(result)

add_query(query, future, past)

Add a query to be evaluated in the next batch.

This method is called internally when a next_token_logprobs request is made.

Parameters:

- query (list[int]): Token IDs representing the query prompt. Required.
- future (asyncio.Future): Future to store the result in. Required.
- past (list[tuple[torch.Tensor]] | None): Past key/value states from previous evaluation, or None if this is a new query. Required.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def add_query(self, query, future, past):
    """Add a query to be evaluated in the next batch.

    This method is called internally when a `next_token_logprobs` request is made.

    Args:
        query (list[int]): Token IDs representing the query prompt
        future (asyncio.Future): Future to store the result in
        past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
            or None if this is a new query
    """
    self.queries.append(Query(query, future, past))

    if self.timer:
        self.timer.cancel()
        self.timer = None
    if len(self.queries) >= self.batch_size:
        self.batch_evaluate_queries()
    else:
        self.timer = asyncio.get_running_loop().call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

- token_ids (list[int]): Sequence of token IDs to follow in the cache tree. Required.

Returns:

(tuple):
- CacheNode: The deepest node in the cache tree that matches the token sequence.
- int: Number of tokens matched from the start of token_ids.
- list[tuple[torch.Tensor]] | None: Past key/value states from the deepest cached node, or None if no cached states were found.
- int: Base index indicating where the past states start in token_ids.

Source code in genlm/backend/llm/hf.py
def walk_cache(self, token_ids):
    """Walk the cache tree to find the deepest node matching a sequence of tokens.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple:
            - CacheNode: The deepest node in the cache tree that matches the token sequence
            - int: Number of tokens matched from the start of token_ids
            - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                or None if no cached states were found
            - int: Base index indicating where the past states start in token_ids
    """
    # Walk while tokens can be found
    node = self.cache
    next_token_index = 0

    past = None
    base = 0
    while next_token_index < len(token_ids):
        if node.past_key_values is not None:
            past = node.past_key_values
            base = next_token_index
        if node.has_token(token_ids[next_token_index]):
            node = node.get_token(token_ids[next_token_index])
            next_token_index += 1
        else:
            break

    return node, next_token_index, past, base

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

- token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

logprobs (torch.Tensor): A tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    node, next_token_index, past, base = self.walk_cache(token_ids)

    # If we processed all tokens, then we're done.
    if next_token_index == len(token_ids):
        return node.logprobs

    # Create a future with the prompt
    future = asyncio.get_running_loop().create_future()
    self.add_query(token_ids[base:], future, past)
    logits = await future

    # Create new nodes
    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs
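
Concurrent awaits issued within the timeout window are grouped into a single forward pass by the auto-batcher. A sketch, to be awaited inside a running event loop with any AsyncTransformer instance:

import asyncio

async def score_many(llm, texts):
    prompts = [llm.tokenizer.encode(t) for t in texts]
    # All of these requests are queued and evaluated together in one batch.
    return await asyncio.gather(*(llm.next_token_logprobs(p) for p in prompts))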

next_token_logprobs_sync(token_ids)

Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

Parameters:

- token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

logprobs (torch.Tensor): A tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    # Walk while tokens can be found
    node, next_token_index, past, base = self.walk_cache(token_ids)

    if next_token_index == len(token_ids):
        return node.logprobs

    logits = self.model(
        torch.tensor([token_ids[base:]]).to(self.device),
        past_key_values=node.past_key_values,
        use_cache=node.past_key_values is not None,
    ).logits[0]

    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_uncached(token_ids)

Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

Parameters:

- token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

logprobs (torch.Tensor): A tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
def next_token_logprobs_uncached(self, token_ids):
    """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    with torch.no_grad():
        logits = self.model(
            torch.tensor([token_ids]).to(self.device),
            past_key_values=None,
            use_cache=False,
        ).logits[0]
        return torch.log_softmax(logits[-1], dim=0)

AsyncLM

Bases: ABC

Abstract base class for asynchronous language models.

This class provides an interface for language models that can generate token probabilities asynchronously. It handles tokenization and vocabulary management.

Parameters:

- tokenizer: A Hugging Face tokenizer instance compatible with the language model. Required.

Source code in genlm/backend/llm/base.py
class AsyncLM(ABC):
    """Abstract base class for asynchronous language models.

    This class provides an interface for language models that can generate token probabilities
    asynchronously. It handles tokenization and vocabulary management.

    Args:
        tokenizer: A Hugging Face tokenizer instance compatible with the language model
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.byte_vocab, self.str_vocab = decode_vocab(self.tokenizer)

    @abstractmethod
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously.

        Args:
            token_ids (list[int]): A list of token IDs representing the prompt.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        pass

    @abstractmethod
    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Args:
            token_ids (list[int]): A list of token IDs representing the prompt.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        pass

    async def batch_next_token_logprobs(self, token_ids_list):
        """Batch request log probabilities for multiple token sequences asynchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists.

        Returns:
            (torch.Tensor): A tensor of log probability tensors.
        """
        logprobs = await asyncio.gather(
            *[self.next_token_logprobs(token_ids) for token_ids in token_ids_list]
        )

        return torch.stack(logprobs)

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """Batch request log probabilities for multiple token sequences synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists.

        Returns:
            (torch.Tensor): A tensor of log probability tensors.
        """
        return torch.stack(
            [self.next_token_logprobs_sync(token_ids) for token_ids in token_ids_list]
        )

    def clear_cache(self):
        """Clear any caches used by the language model. No-op in base class."""
        pass  # pragma: no cover

    async def sample(
        self, prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None
    ):
        """Sample from the language model.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
            temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
            max_tokens (int): The maximum number of tokens to generate.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[int]): The sampled token IDs.
        """
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = None

        generated_token_ids = []
        for _ in range(max_tokens):
            logprobs = await self.next_token_logprobs(
                prompt_token_ids + generated_token_ids
            )
            probs = torch.softmax(logprobs / temperature, dim=-1)
            next_token_id = torch.multinomial(
                probs.cpu() if seed is not None else probs,
                num_samples=1,
                generator=generator,
            ).item()
            if next_token_id in eos_token_ids:
                break
            generated_token_ids.append(next_token_id)

        return generated_token_ids

    async def batch_sample(
        self,
        prompt_token_ids_list,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Batch sample from the language model.

        Args:
            prompt_token_ids_list (list[list[int]]): The token IDs of the prompts.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
            temperature (float): The temperature to use for the logits.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[list[int]]): The sampled token IDs.
        """
        return await asyncio.gather(
            *[
                self.sample(
                    prompt_token_ids=prompt_token_ids,
                    max_tokens=max_tokens,
                    eos_token_ids=eos_token_ids,
                    temperature=temperature,
                    seed=seed,
                )
                for prompt_token_ids in prompt_token_ids_list
            ]
        )
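
A sketch of the subclass contract: implement the two abstract next-token methods and the batched and sampling helpers are inherited. The class below is a toy (uniform distribution over the vocabulary) and assumes the import path follows the source location above; a real instance needs a Hugging Face tokenizer.

import math
import torch
from genlm.backend.llm.base import AsyncLM

class UniformAsyncLM(AsyncLM):
    """Toy AsyncLM that assigns equal probability to every token."""

    async def next_token_logprobs(self, token_ids):
        return self.next_token_logprobs_sync(token_ids)

    def next_token_logprobs_sync(self, token_ids):
        vocab_size = len(self.tokenizer)
        return torch.full((vocab_size,), -math.log(vocab_size))

# e.g. UniformAsyncLM(AutoTokenizer.from_pretrained("gpt2"))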

next_token_logprobs(token_ids) abstractmethod async

Request log probabilities of next token asynchronously.

Parameters:

- token_ids (list[int]): A list of token IDs representing the prompt. Required.

Returns:

(torch.Tensor): Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
@abstractmethod
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token asynchronously.

    Args:
        token_ids (list[int]): A list of token IDs representing the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    pass

next_token_logprobs_sync(token_ids) abstractmethod

Request log probabilities of next token synchronously.

Parameters:

- token_ids (list[int]): A list of token IDs representing the prompt. Required.

Returns:

(torch.Tensor): Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
@abstractmethod
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    Args:
        token_ids (list[int]): A list of token IDs representing the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    pass

batch_next_token_logprobs(token_ids_list) async

Batch request log probabilities for multiple token sequences asynchronously.

Parameters:

- token_ids_list (list[list[int]]): A list of token ID lists. Required.

Returns:

(torch.Tensor): A tensor of log probability tensors.

Source code in genlm/backend/llm/base.py
async def batch_next_token_logprobs(self, token_ids_list):
    """Batch request log probabilities for multiple token sequences asynchronously.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists.

    Returns:
        (torch.Tensor): A tensor of log probability tensors.
    """
    logprobs = await asyncio.gather(
        *[self.next_token_logprobs(token_ids) for token_ids in token_ids_list]
    )

    return torch.stack(logprobs)

batch_next_token_logprobs_sync(token_ids_list)

Batch request log probabilities for multiple token sequences synchronously.

Parameters:

- token_ids_list (list[list[int]]): A list of token ID lists. Required.

Returns:

(torch.Tensor): A tensor of log probability tensors.

Source code in genlm/backend/llm/base.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """Batch request log probabilities for multiple token sequences synchronously.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists.

    Returns:
        (torch.Tensor): A tensor of log probability tensors.
    """
    return torch.stack(
        [self.next_token_logprobs_sync(token_ids) for token_ids in token_ids_list]
    )

clear_cache()

Clear any caches used by the language model. No-op in base class.

Source code in genlm/backend/llm/base.py
def clear_cache(self):
    """Clear any caches used by the language model. No-op in base class."""
    pass  # pragma: no cover

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Parameters:

- prompt_token_ids (list[int]): The token IDs of the prompt. Required.
- max_tokens (int): The maximum number of tokens to generate. Required.
- eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens. Required.
- temperature (float): The temperature to use to rescale the logits. Defaults to 1.0.
- seed (int): The seed for the random number generator. Defaults to None.

Returns:

(list[int]): The sampled token IDs.

Source code in genlm/backend/llm/base.py
async def sample(
    self, prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None
):
    """Sample from the language model.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
        max_tokens (int): The maximum number of tokens to generate.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs.
    """
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
    else:
        generator = None

    generated_token_ids = []
    for _ in range(max_tokens):
        logprobs = await self.next_token_logprobs(
            prompt_token_ids + generated_token_ids
        )
        probs = torch.softmax(logprobs / temperature, dim=-1)
        next_token_id = torch.multinomial(
            probs.cpu() if seed is not None else probs,
            num_samples=1,
            generator=generator,
        ).item()
        if next_token_id in eos_token_ids:
            break
        generated_token_ids.append(next_token_id)

    return generated_token_ids
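
Since the next-token distributions are deterministic for a given prompt, passing the same seed should reproduce the same continuation. A sketch, assuming some concrete AsyncLM instance llm, to be awaited inside a running event loop:

async def reproducible_sample(llm, prompt_ids, eos_ids):
    first = await llm.sample(prompt_ids, max_tokens=10, eos_token_ids=eos_ids, seed=7)
    second = await llm.sample(prompt_ids, max_tokens=10, eos_token_ids=eos_ids, seed=7)
    assert first == second   # same seed, same prompt, same sampled token IDs
    return first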

batch_sample(prompt_token_ids_list, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Batch sample from the language model.

Parameters:

- prompt_token_ids_list (list[list[int]]): The token IDs of the prompts. Required.
- max_tokens (int): The maximum number of tokens to generate. Required.
- eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens. Required.
- temperature (float): The temperature to use for the logits. Defaults to 1.0.
- seed (int): The seed for the random number generator. Defaults to None.

Returns:

(list[list[int]]): The sampled token IDs.

Source code in genlm/backend/llm/base.py
async def batch_sample(
    self,
    prompt_token_ids_list,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Batch sample from the language model.

    Args:
        prompt_token_ids_list (list[list[int]]): The token IDs of the prompts.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float): The temperature to use for the logits.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[list[int]]): The sampled token IDs.
    """
    return await asyncio.gather(
        *[
            self.sample(
                prompt_token_ids=prompt_token_ids,
                max_tokens=max_tokens,
                eos_token_ids=eos_token_ids,
                temperature=temperature,
                seed=seed,
            )
            for prompt_token_ids in prompt_token_ids_list
        ]
    )
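
Continuing the sketch above, batch_sample fans sample out over several prompts concurrently (MockAsyncLM again, purely for illustration):

import asyncio
from genlm.backend.llm import MockAsyncLM  # assumed import path

llm = MockAsyncLM.from_name("gpt2")
prompts = [llm.tokenizer.encode(p) for p in ["The cat", "The dog"]]

samples = asyncio.run(
    llm.batch_sample(
        prompt_token_ids_list=prompts,
        max_tokens=5,
        eos_token_ids=[llm.tokenizer.eos_token_id],
    )
)
for prompt, generated in zip(prompts, samples):
    print(llm.tokenizer.decode(prompt + generated))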

MockAsyncLM

Bases: AsyncLM

Mock implementation of AsyncLM used for testing.

Source code in genlm/backend/llm/base.py
class MockAsyncLM(AsyncLM):
    """Mock implementation of AsyncLM used for testing."""

    def __init__(self, tokenizer):
        """Initialize a `MockAsyncLM` instance.

        Args:
            tokenizer: Hugging Face tokenizer instance
        """
        super().__init__(tokenizer)
        self._rng = np.random.RandomState(42)

    @classmethod
    def from_name(cls, model_name, **kwargs):
        """Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

        Args:
            model_name (str): Name of pretrained model to load tokenizer from
            **kwargs: Additional arguments passed to `MockAsyncLM` constructor

        Returns:
            (MockAsyncLM): `MockAsyncLM` instance initialized with tokenizer from `model_name`
        """
        from transformers import AutoTokenizer

        return cls(AutoTokenizer.from_pretrained(model_name), **kwargs)

    async def next_token_logprobs(self, token_ids):
        """Get next token log probabilities asynchronously.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self._get_logprobs(token_ids)

    def next_token_logprobs_sync(self, token_ids):
        """Get next token log probabilities synchronously.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self._get_logprobs(token_ids)

    def _get_logprobs(self, token_ids):
        """Generate random but deterministic log probabilities for given tokens.

        Uses token_ids to seed the random generator, ensuring same inputs produce same outputs.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        seed = sum([(i + 1) * t for i, t in enumerate(token_ids)])
        self._rng.seed(seed)
        logits = torch.from_numpy(
            self._rng.rand(len(self.tokenizer)).astype(np.float32)
        )
        return torch.log_softmax(logits, dim=-1)
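
The mock is useful in tests precisely because its outputs are random but reproducible: the same token IDs always yield the same distribution. A minimal sketch (import path assumed):

import asyncio
import torch
from genlm.backend.llm import MockAsyncLM  # assumed import path

llm = MockAsyncLM.from_name("gpt2")

a = asyncio.run(llm.next_token_logprobs([1, 2, 3]))
b = llm.next_token_logprobs_sync([1, 2, 3])

assert torch.equal(a, b)  # same inputs, same outputs
assert torch.isclose(a.exp().sum(), torch.tensor(1.0), atol=1e-4)  # normalized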

__init__(tokenizer)

Initialize a MockAsyncLM instance.

Parameters:

tokenizer (required): Hugging Face tokenizer instance.
Source code in genlm/backend/llm/base.py
def __init__(self, tokenizer):
    """Initialize a `MockAsyncLM` instance.

    Args:
        tokenizer: Hugging Face tokenizer instance
    """
    super().__init__(tokenizer)
    self._rng = np.random.RandomState(42)

from_name(model_name, **kwargs) classmethod

Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

Parameters:

model_name (str, required): Name of pretrained model to load tokenizer from.
**kwargs: Additional arguments passed to MockAsyncLM constructor.

Returns:

MockAsyncLM: MockAsyncLM instance initialized with tokenizer from model_name.

Source code in genlm/backend/llm/base.py
@classmethod
def from_name(cls, model_name, **kwargs):
    """Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

    Args:
        model_name (str): Name of pretrained model to load tokenizer from
        **kwargs: Additional arguments passed to `MockAsyncLM` constructor

    Returns:
        (MockAsyncLM): `MockAsyncLM` instance initialized with tokenizer from `model_name`
    """
    from transformers import AutoTokenizer

    return cls(AutoTokenizer.from_pretrained(model_name), **kwargs)

next_token_logprobs(token_ids) async

Get next token log probabilities asynchronously.

Parameters:

token_ids (list[int], required): Input token IDs.

Returns:

Tensor: Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
async def next_token_logprobs(self, token_ids):
    """Get next token log probabilities asynchronously.

    Args:
        token_ids (list[int]): Input token IDs.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    return self._get_logprobs(token_ids)

next_token_logprobs_sync(token_ids)

Get next token log probabilities synchronously.

Parameters:

token_ids (list[int], required): Input token IDs.

Returns:

Tensor: Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
def next_token_logprobs_sync(self, token_ids):
    """Get next token log probabilities synchronously.

    Args:
        token_ids (list[int]): Input token IDs.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    return self._get_logprobs(token_ids)

AsyncMlxLM

Bases: AsyncLM

Asynchronous MLX-based language model wrapper.

This class provides an async interface to MLX language models with automatic batching, caching, and KV cache management. It extends AsyncLM to provide efficient batched inference with prefix caching.

The model automatically batches concurrent requests and uses a trie-based cache to store computed log probabilities and KV states for reuse.

Source code in genlm/backend/llm/mlx.py
class AsyncMlxLM(AsyncLM):
    """Asynchronous MLX-based language model wrapper.

    This class provides an async interface to MLX language models with
    automatic batching, caching, and KV cache management. It extends
    AsyncLM to provide efficient batched inference with prefix caching.

    The model automatically batches concurrent requests and uses a trie-based
    cache to store computed log probabilities and KV states for reuse.
    """

    def __init__(
        self,
        mlx_lm_model,
        tokenizer,
        batch_size=5,
        timeout=0.001,
        prefill_step_size=2048,
        cache_size=400,
    ):
        """Initialize an `AsyncMlxLM` instance.

        Args:
            mlx_lm_model: The MLX language model instance.
            tokenizer: The tokenizer for encoding/decoding text.
            batch_size (int, optional): Maximum number of queries to batch
                together.
            timeout (float, optional): Maximum time in seconds to wait
                before processing a batch, even if batch_size is not met.
            prefill_step_size (int, optional): Number of tokens to process
                per step during prompt prefilling.
            cache_size (int, optional): Maximum number of KV cache entries
                to keep in memory.
        """
        self.mlx_lm_model = mlx_lm_model
        self.tokenizer = tokenizer
        self.cache = DynamicTokenTrie()
        self.generation_stream = mx.new_stream(mx.default_device())
        self.queries = []
        self.timeout = timeout
        self.timer = None
        self.prefill_step_size = prefill_step_size
        self.cache_size = cache_size

        self.batch_size = batch_size
        self.kv_cachable = self._kv_cachable(self.mlx_lm_model)
        if not self.kv_cachable:
            warnings.warn(
                f"Model {type(self.mlx_lm_model).__name__} does not support KV caching; "
                f"prefix caching will be disabled.",
                UserWarning,
                stacklevel=2,
            )
        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, **kwargs):
        """Create an `AsyncMlxLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load. Can be a Hugging Face
                model identifier or local path.
            **kwargs: Additional arguments passed to `AsyncMlxLM` constructor,
                such as `batch_size`, `timeout`, `prefill_step_size`, `cache_size`.

        Returns:
            AsyncMlxLM: An `AsyncMlxLM` instance with the loaded model and tokenizer.
        """

        model, tokenizer = mlx_lm.load(model_name)
        return cls(model, tokenizer, **kwargs)

    @staticmethod
    def _to_torch(logprobs):
        """Convert MLX arrays into PyTorch tensors."""
        if logprobs.dtype == mx.bfloat16:
            logprobs = logprobs.astype(mx.float16)
        return torch.tensor(logprobs)

    @staticmethod
    def _kv_cachable(mlx_lm_model):
        """Check if an MLX model supports KV cache storage.

        A model is KV-cacheable if all its cache layers are KVCache or
        RotatingKVCache with keep=0.
        """
        if not hasattr(mlx_lm_model, "make_cache"):
            return True
        cache = mlx_lm_model.make_cache()
        return all(
            isinstance(c, KVCache)
            or (isinstance(c, RotatingKVCache) and c.keep == 0)
            for c in cache
        )

    def clear_cache(self):
        """Clear the output cache and MLX device cache.

        This method resets the internal token trie cache and clears
        any cached arrays on the MLX device to free memory.
        """
        if self.cache is not None:
            self.cache = DynamicTokenTrie()
        mx.clear_cache()

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple: A 5-tuple containing:
                - node: The deepest node in the cache tree that matches
                the token sequence, regardless of whether its KV states are cached
                - next_token_index: Number of tokens matched from the start of token_ids
                - past_kvs: Past key/value states concatenated from cached nodes, or None if no cached states were found
                - kv_node: The cache node where KV states start
                - kv_next_token_index: Number of tokens matched from the start of token_ids for the KV states
        """
        # Walk while tokens can be found
        node = self.cache
        kv_next_token_index = 0
        kv_node = node
        collecting = True
        next_token_index = 0
        past_kvs = []

        while next_token_index < len(token_ids):
            if node.past_key_values is not None and collecting:
                past_kvs.append(node.past_key_values)
                kv_node = node
                kv_next_token_index = next_token_index
            elif next_token_index > 0:
                collecting = False
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        past_kvs = None if len(past_kvs) == 0 else mx.concatenate(past_kvs, axis=3)

        return node, next_token_index, past_kvs, kv_node, kv_next_token_index

    def cache_kv(self, token_ids):
        """Pre-compute and cache KV states for a given token sequence."""
        query = Query(token_ids, None, None, self.cache, 0)
        self._batch_logits_custom([query])

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
        to completion."""
        self.queries = []

    def add_to_cache(self, queries, prompt_cache=None, logprobs=None):
        """Add computed log probabilities and KV states to the cache tree."""
        left_paddings = prompt_cache[0].left_padding.tolist()
        for i, query in enumerate(queries):
            token_ids, node, next_token_index = (
                query.prompt,
                query.node,
                query.next_token_index,
            )
            if node is None or next_token_index is None:
                node = self.cache
                next_token_index = 0
            lp = left_paddings[i]
            if prompt_cache is not None and self.kv_cachable:
                keys = [
                    c.keys[i, :, lp + next_token_index : lp + len(token_ids), :]
                    for c in prompt_cache
                ]
                values = [
                    c.values[i, :, lp + next_token_index : lp + len(token_ids), :]
                    for c in prompt_cache
                ]
                keys = mx.stack(keys, axis=0)
                values = mx.stack(values, axis=0)
                keys_values = mx.stack([keys, values], axis=0)
                node.extend_cache(
                    next_token_index, token_ids, logprobs[i], keys_values
                )
            else:
                node.extend_cache(next_token_index, token_ids, logprobs[i])

        self.cache.evict_lru_kv(self.cache_size)

    def _process_kv(self, left_paddings, prompt_cache, pasts=None, step_size=256):
        """Process and integrate past KV cache states into prompt cache.

        This method takes past key-value cache states from the cache tree
        and integrates them into the prompt cache for efficient prefix
        reuse. It handles padding and alignment of cache states across
        different query lengths.

        Args:
            left_paddings (list[int]): Left padding amounts for each query
                in the batch.
            prompt_cache (list): List of cache objects to update with
                past states.
            pasts (list[mx.array], optional): List of past KV cache states,
                one per query.
            step_size (int, optional): Step size for cache size alignment.

        Returns:
            tuple: A 2-tuple containing:
                - list: Updated prompt_cache objects
                - cached_len: Number of tokens that were cached
        """
        if pasts is None or all(past is None for past in pasts):
            return prompt_cache, 0
        max_match_lengths = [0 if past is None else past.shape[3] for past in pasts]
        min_pos_cached = min(
            ml + lp for ml, lp in zip(max_match_lengths, left_paddings)
        )
        cache_grabs = [max(min_pos_cached - lp, 0) for lp in left_paddings]
        non_zero_index = next(
            (i for i, grab in enumerate(cache_grabs) if grab), None
        )
        if non_zero_index is None:
            return prompt_cache, 0
        _, num_layers, N, _, D = pasts[non_zero_index].shape
        cache_size = (step_size + min_pos_cached - 1) // step_size * step_size
        right_paddings = [
            max(cache_size - lp - max_len, 0)
            for lp, max_len in zip(left_paddings, max_match_lengths)
        ]
        padded_pasts = []
        for past, lp, rp in zip(pasts, left_paddings, right_paddings):
            if past is None:
                padded_pasts.append(mx.zeros((2, num_layers, N, cache_size, D)))
            else:
                padded_pasts.append(
                    mx.pad(
                        past[:, :, :, : cache_size - lp, :],
                        ((0, 0), (0, 0), (0, 0), (lp, rp), (0, 0)),
                    )
                )

        padded_pasts = mx.stack(padded_pasts, axis=2)
        for i, cache in enumerate(prompt_cache):
            cache.keys = padded_pasts[0, i]
            cache.values = padded_pasts[1, i]
            cache.offset += min_pos_cached
            cache._idx += min_pos_cached
        return prompt_cache, min_pos_cached

    def _process_prompts(self, queries):
        """Process a batch of prompts and compute next-token log probabilities."""
        inputs = [q.prompt for q in queries]
        pasts = [q.past for q in queries]
        lengths = [len(p) for p in inputs]
        max_length = max(lengths)
        left_padding = [max_length - length for length in lengths]
        prompt_cache = _make_cache(self.mlx_lm_model, left_padding)
        inputs_padded = _left_pad_prompts(inputs, max_length=max_length)

        if self.kv_cachable:
            prompt_cache, cached_len = self._process_kv(
                left_padding, prompt_cache, pasts
            )
        else:
            cached_len = 0
        inputs_padded = inputs_padded[:, cached_len:]

        while inputs_padded.shape[1] > 1:
            n_to_process = min(self.prefill_step_size, inputs_padded.shape[1] - 1)
            self.mlx_lm_model(inputs_padded[:, :n_to_process], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            inputs_padded = inputs_padded[:, n_to_process:]

        logits = self.mlx_lm_model(inputs_padded, cache=prompt_cache)
        logits = logits[:, -1, :]
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        mx.async_eval(logprobs)

        return logprobs, prompt_cache

    def _batch_logits_custom(
        self,
        queries,
    ):
        """Compute next-token log probabilities for each query in a batch and add to cache.
        Args:
            queries (list[Query]): List of query objects to process.
        Returns:
            logprobs (list[torch.Tensor]): List of normalized log probability tensors."""
        with wired_limit(self.mlx_lm_model, [self.generation_stream]):
            logprobs, prompt_cache = self._process_prompts(queries)
            logprobs = AsyncMlxLM._to_torch(logprobs)
        mx.clear_cache()
        self.add_to_cache(queries, prompt_cache, logprobs)
        return logprobs

    def batch_evaluate_queries(self):
        """Process a batch of queued language model queries."""

        queries, self.queries = self.queries, []
        if len(queries) == 0:
            return

        query_groups = defaultdict(list)
        for query in queries:
            key = tuple(query.prompt)
            query_groups[key].append(query)

        # Use one representative query from each group
        unique_queries = [group[0] for group in query_groups.values()]

        results = self._batch_logits_custom(unique_queries)

        assert len(results) == len(unique_queries)

        for i, q in enumerate(unique_queries):
            for dup_query in query_groups[tuple(q.prompt)]:
                dup_query.future.set_result(results[i])

    def add_query(self, query):
        """Add a query to be evaluated in the next batch and reset the timeout."""
        self.queries.append(query)

        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's normalized log probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, kv_node, kv_next_token_index = (
            self.walk_cache(token_ids)
        )
        if next_token_index == len(token_ids) and node.logprobs is not None:
            return node.logprobs

        future = asyncio.get_running_loop().create_future()
        query = Query(token_ids, future, past, kv_node, kv_next_token_index)
        self.add_query(query)
        logprobs = await future

        return logprobs

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, kv_node, kv_next_token_index = (
            self.walk_cache(token_ids)
        )
        if next_token_index == len(token_ids) and node.logprobs is not None:
            return node.logprobs

        query = Query(token_ids, None, past, kv_node, kv_next_token_index)
        logprobs = self._batch_logits_custom([query])[0]

        return logprobs

    async def sample(
        self,
        prompt_token_ids,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample from the language model.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt to
                start generation from.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs that signal
                end-of-sequence. Generation stops when one of these is
                sampled.
            temperature (float, optional): The temperature to use for
                sampling. Higher values make the distribution more uniform,
                lower values make it more peaked. Defaults to 1.0.
            seed (int, optional): The seed for the random number generator.
                If provided, sets the random seed before sampling.
                Defaults to None.

        Returns:
            (list[int]): The sampled token IDs.
        """

        if seed is not None:
            mx.random.seed(seed)

        sampler = make_sampler(temp=temperature)
        prompt_token_ids_array = mx.array(prompt_token_ids)
        token_generator = generate_step(
            prompt_token_ids_array,
            self.mlx_lm_model,
            max_tokens=max_tokens,
            sampler=sampler,
        )
        generated_token_ids = []
        for sampled, _ in token_generator:
            if sampled in eos_token_ids:
                break
            generated_token_ids.append(sampled)
        return generated_token_ids

__init__(mlx_lm_model, tokenizer, batch_size=5, timeout=0.001, prefill_step_size=2048, cache_size=400)

Initialize an AsyncMlxLM instance.

Parameters:

mlx_lm_model (required): The MLX language model instance.
tokenizer (required): The tokenizer for encoding/decoding text.
batch_size (int, default 5): Maximum number of queries to batch together.
timeout (float, default 0.001): Maximum time in seconds to wait before processing a batch, even if batch_size is not met.
prefill_step_size (int, default 2048): Number of tokens to process per step during prompt prefilling.
cache_size (int, default 400): Maximum number of KV cache entries to keep in memory.
Source code in genlm/backend/llm/mlx.py
def __init__(
    self,
    mlx_lm_model,
    tokenizer,
    batch_size=5,
    timeout=0.001,
    prefill_step_size=2048,
    cache_size=400,
):
    """Initialize an `AsyncMlxLM` instance.

    Args:
        mlx_lm_model: The MLX language model instance.
        tokenizer: The tokenizer for encoding/decoding text.
        batch_size (int, optional): Maximum number of queries to batch
            together.
        timeout (float, optional): Maximum time in seconds to wait
            before processing a batch, even if batch_size is not met.
        prefill_step_size (int, optional): Number of tokens to process
            per step during prompt prefilling.
        cache_size (int, optional): Maximum number of KV cache entries
            to keep in memory.
    """
    self.mlx_lm_model = mlx_lm_model
    self.tokenizer = tokenizer
    self.cache = DynamicTokenTrie()
    self.generation_stream = mx.new_stream(mx.default_device())
    self.queries = []
    self.timeout = timeout
    self.timer = None
    self.prefill_step_size = prefill_step_size
    self.cache_size = cache_size

    self.batch_size = batch_size
    self.kv_cachable = self._kv_cachable(self.mlx_lm_model)
    if not self.kv_cachable:
        warnings.warn(
            f"Model {type(self.mlx_lm_model).__name__} does not support KV caching; "
            f"prefix caching will be disabled.",
            UserWarning,
            stacklevel=2,
        )
    super().__init__(tokenizer=self.tokenizer)

from_name(model_name, **kwargs) classmethod

Create an AsyncMlxLM instance from a model name.

Parameters:

model_name (str, required): Name of the model to load. Can be a Hugging Face model identifier or local path.
**kwargs: Additional arguments passed to AsyncMlxLM constructor, such as batch_size, timeout, prefill_step_size, cache_size.

Returns:

AsyncMlxLM: An AsyncMlxLM instance with the loaded model and tokenizer.

Source code in genlm/backend/llm/mlx.py
@classmethod
def from_name(cls, model_name, **kwargs):
    """Create an `AsyncMlxLM` instance from a model name.

    Args:
        model_name (str): Name of the model to load. Can be a Hugging Face
            model identifier or local path.
        **kwargs: Additional arguments passed to `AsyncMlxLM` constructor,
            such as `batch_size`, `timeout`, `prefill_step_size`, `cache_size`.

    Returns:
        AsyncMlxLM: An `AsyncMlxLM` instance with the loaded model and tokenizer.
    """

    model, tokenizer = mlx_lm.load(model_name)
    return cls(model, tokenizer, **kwargs)
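
For example, the batching and caching behaviour can be tuned at load time (a minimal sketch; it requires Apple silicon with mlx-lm installed, and the checkpoint name is only an illustrative MLX-converted model):

from genlm.backend.llm import AsyncMlxLM  # assumed import path

llm = AsyncMlxLM.from_name(
    "mlx-community/Qwen2.5-0.5B-Instruct-4bit",  # illustrative checkpoint
    batch_size=8,     # evaluate up to 8 concurrent requests per forward pass
    timeout=0.005,    # or flush the pending batch after 5 ms
    cache_size=200,   # keep at most 200 KV entries in the trie cache
)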

clear_cache()

Clear the output cache and MLX device cache.

This method resets the internal token trie cache and clears any cached arrays on the MLX device to free memory.

Source code in genlm/backend/llm/mlx.py
def clear_cache(self):
    """Clear the output cache and MLX device cache.

    This method resets the internal token trie cache and clears
    any cached arrays on the MLX device to free memory.
    """
    if self.cache is not None:
        self.cache = DynamicTokenTrie()
    mx.clear_cache()

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

token_ids (list[int], required): Sequence of token IDs to follow in the cache tree.

Returns:

tuple: A 5-tuple containing:
- node: The deepest node in the cache tree that matches the token sequence, regardless of whether its KV states are cached.
- next_token_index: Number of tokens matched from the start of token_ids.
- past_kvs: Past key/value states concatenated from cached nodes, or None if no cached states were found.
- kv_node: The cache node where the KV states start.
- kv_next_token_index: Number of tokens matched from the start of token_ids for the KV states.

Source code in genlm/backend/llm/mlx.py
def walk_cache(self, token_ids):
    """Walk the cache tree to find the deepest node matching a sequence of tokens.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple: A 5-tuple containing:
            - node: The deepest node in the cache tree that matches
            the token sequence, regardless of whether its KV states are cached
            - next_token_index: Number of tokens matched from the start of token_ids
            - past_kvs: Past key/value states concatenated from cached nodes, or None if no cached states were found
            - kv_node: The cache node where KV states start
            - kv_next_token_index: Number of tokens matched from the start of token_ids for the KV states
    """
    # Walk while tokens can be found
    node = self.cache
    kv_next_token_index = 0
    kv_node = node
    collecting = True
    next_token_index = 0
    past_kvs = []

    while next_token_index < len(token_ids):
        if node.past_key_values is not None and collecting:
            past_kvs.append(node.past_key_values)
            kv_node = node
            kv_next_token_index = next_token_index
        elif next_token_index > 0:
            collecting = False
        if node.has_token(token_ids[next_token_index]):
            node = node.get_token(token_ids[next_token_index])
            next_token_index += 1
        else:
            break

    past_kvs = None if len(past_kvs) == 0 else mx.concatenate(past_kvs, axis=3)

    return node, next_token_index, past_kvs, kv_node, kv_next_token_index

cache_kv(token_ids)

Pre-compute and cache KV states for a given token sequence.

Source code in genlm/backend/llm/mlx.py
def cache_kv(self, token_ids):
    """Pre-compute and cache KV states for a given token sequence."""
    query = Query(token_ids, None, None, self.cache, 0)
    self._batch_logits_custom([query])
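
This is handy for pre-warming a long shared prefix (for example a system prompt) so that later queries only pay for their suffix. A minimal sketch, assuming llm is an AsyncMlxLM instance as above:

prefix = llm.tokenizer.encode("You are a helpful assistant. Answer concisely.\n")
llm.cache_kv(prefix)  # computes and stores KV states for the prefix

# Later queries that extend the prefix reuse its cached KV states.
question = llm.tokenizer.encode("Q: What is 2 + 2?\nA:")
logprobs = llm.next_token_logprobs_sync(prefix + question)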

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/mlx.py
def reset_async_queries(self):
    """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
    to completion."""
    self.queries = []

add_to_cache(queries, prompt_cache=None, logprobs=None)

Add computed log probabilities and KV states to the cache tree.

Source code in genlm/backend/llm/mlx.py
def add_to_cache(self, queries, prompt_cache=None, logprobs=None):
    """Add computed log probabilities and KV states to the cache tree."""
    left_paddings = prompt_cache[0].left_padding.tolist()
    for i, query in enumerate(queries):
        token_ids, node, next_token_index = (
            query.prompt,
            query.node,
            query.next_token_index,
        )
        if node is None or next_token_index is None:
            node = self.cache
            next_token_index = 0
        lp = left_paddings[i]
        if prompt_cache is not None and self.kv_cachable:
            keys = [
                c.keys[i, :, lp + next_token_index : lp + len(token_ids), :]
                for c in prompt_cache
            ]
            values = [
                c.values[i, :, lp + next_token_index : lp + len(token_ids), :]
                for c in prompt_cache
            ]
            keys = mx.stack(keys, axis=0)
            values = mx.stack(values, axis=0)
            keys_values = mx.stack([keys, values], axis=0)
            node.extend_cache(
                next_token_index, token_ids, logprobs[i], keys_values
            )
        else:
            node.extend_cache(next_token_index, token_ids, logprobs[i])

    self.cache.evict_lru_kv(self.cache_size)

batch_evaluate_queries()

Process a batch of queued language model queries.

Source code in genlm/backend/llm/mlx.py
def batch_evaluate_queries(self):
    """Process a batch of queued language model queries."""

    queries, self.queries = self.queries, []
    if len(queries) == 0:
        return

    query_groups = defaultdict(list)
    for query in queries:
        key = tuple(query.prompt)
        query_groups[key].append(query)

    # Use one representative query from each group
    unique_queries = [group[0] for group in query_groups.values()]

    results = self._batch_logits_custom(unique_queries)

    assert len(results) == len(unique_queries)

    for i, q in enumerate(unique_queries):
        for dup_query in query_groups[tuple(q.prompt)]:
            dup_query.future.set_result(results[i])

add_query(query)

Add a query to be evaluated in the next batch and reset the timeout.

Source code in genlm/backend/llm/mlx.py
def add_query(self, query):
    """Add a query to be evaluated in the next batch and reset the timeout."""
    self.queries.append(query)

    if self.timer:
        self.timer.cancel()
        self.timer = None
    if len(self.queries) >= self.batch_size:
        self.batch_evaluate_queries()
    else:
        self.timer = asyncio.get_running_loop().call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

token_ids (list[int], required): A list of token IDs, representing a prompt to the language model.

Returns:

logprobs (Tensor): A tensor with the language model's normalized log probabilities for the next token following the prompt.

Source code in genlm/backend/llm/mlx.py
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's normalized log probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    node, next_token_index, past, kv_node, kv_next_token_index = (
        self.walk_cache(token_ids)
    )
    if next_token_index == len(token_ids) and node.logprobs is not None:
        return node.logprobs

    future = asyncio.get_running_loop().create_future()
    query = Query(token_ids, future, past, kv_node, kv_next_token_index)
    self.add_query(query)
    logprobs = await future

    return logprobs
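
Because requests are queued and flushed either when batch_size is reached or after timeout seconds, issuing them concurrently is what unlocks batching. A minimal sketch, assuming llm is an AsyncMlxLM instance:

import asyncio

async def score_next_token(llm, prompts):
    # Concurrent awaits are grouped into batched forward passes
    # (up to llm.batch_size, or whatever arrives within llm.timeout).
    return await asyncio.gather(
        *(llm.next_token_logprobs(llm.tokenizer.encode(p)) for p in prompts)
    )

# logprobs = asyncio.run(score_next_token(llm, ["Paris is", "Rome is"]))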

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Parameters:

token_ids (list[int], required): A list of token IDs, representing a prompt to the language model.

Returns:

Tensor: Normalized log probability tensor.

Source code in genlm/backend/llm/mlx.py
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    node, next_token_index, past, kv_node, kv_next_token_index = (
        self.walk_cache(token_ids)
    )
    if next_token_index == len(token_ids) and node.logprobs is not None:
        return node.logprobs

    query = Query(token_ids, None, past, kv_node, kv_next_token_index)
    logprobs = self._batch_logits_custom([query])[0]

    return logprobs

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Parameters:

prompt_token_ids (list[int], required): The token IDs of the prompt to start generation from.
max_tokens (int, required): The maximum number of tokens to generate.
eos_token_ids (list[int], required): The token IDs that signal end-of-sequence. Generation stops when one of these is sampled.
temperature (float, default 1.0): The temperature to use for sampling. Higher values make the distribution more uniform, lower values make it more peaked.
seed (int, default None): The seed for the random number generator. If provided, sets the random seed before sampling.

Returns:

list[int]: The sampled token IDs.

Source code in genlm/backend/llm/mlx.py
async def sample(
    self,
    prompt_token_ids,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Sample from the language model.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt to
            start generation from.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): The token IDs that signal
            end-of-sequence. Generation stops when one of these is
            sampled.
        temperature (float, optional): The temperature to use for
            sampling. Higher values make the distribution more uniform,
            lower values make it more peaked. Defaults to 1.0.
        seed (int, optional): The seed for the random number generator.
            If provided, sets the random seed before sampling.
            Defaults to None.

    Returns:
        (list[int]): The sampled token IDs.
    """

    if seed is not None:
        mx.random.seed(seed)

    sampler = make_sampler(temp=temperature)
    prompt_token_ids_array = mx.array(prompt_token_ids)
    token_generator = generate_step(
        prompt_token_ids_array,
        self.mlx_lm_model,
        max_tokens=max_tokens,
        sampler=sampler,
    )
    generated_token_ids = []
    for sampled, _ in token_generator:
        if sampled in eos_token_ids:
            break
        generated_token_ids.append(sampled)
    return generated_token_ids

load_model_by_name(name, backend=None, llm_opts=None)

Load a language model by name.

Parameters:

name (str, required): Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct").
backend (str, default None): Backend to use for inference. Can be "vllm", "hf", "mlx" or "mock". If None, defaults to "vllm" if CUDA is available, otherwise "hf".
llm_opts (dict, default None): Additional options to pass to the backend constructor. See AsyncVirtualLM and AsyncTransformer documentation for details.

Returns:

AsyncLM: An asynchronous language model.

Raises:

ValueError: If an invalid backend is specified.

Source code in genlm/backend/llm/__init__.py
def load_model_by_name(name, backend=None, llm_opts=None):
    """Load a language model by name.

    Args:
        name (str): Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct")
        backend (str, optional): Backend to use for inference. Can be "vllm", "hf", "mlx" or "mock".
            If None, defaults to "vllm" if CUDA is available, otherwise "hf".
        llm_opts (dict, optional): Additional options to pass to the backend constructor.
            See AsyncVirtualLM and AsyncTransformer documentation for details.

    Returns:
        (AsyncLM): An asynchronous language model.

    Raises:
        (ValueError): If an invalid backend is specified.
    """
    if backend is None:
        backend = "vllm" if torch.cuda.is_available() else "hf"

    if llm_opts is None:
        llm_opts = {}

    if backend == "vllm":
        return AsyncVirtualLM.from_name(name, **llm_opts)
    elif backend == "hf":
        return AsyncTransformer.from_name(name, **llm_opts)
    elif backend == "mock":
        return MockAsyncLM.from_name(name, **llm_opts)
    elif backend == "mlx":
        return AsyncMlxLM.from_name(name, **llm_opts)
    else:
        raise ValueError(f"Invalid backend: {backend}")
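
For example (a minimal sketch; "gpt2" is just an illustrative model name, and llm_opts is forwarded verbatim to the chosen backend's constructor):

from genlm.backend.llm import load_model_by_name  # assumed import path

# Defaults to the vLLM backend when CUDA is available, otherwise Hugging Face.
llm = load_model_by_name("gpt2")

# Or pick a backend explicitly; "mock" is convenient for tests.
mock_llm = load_model_by_name("gpt2", backend="mock")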