backend

AsyncVirtualLM

Bases: AsyncLM

A wrapper around vLLM's AsyncLLMEngine for asynchronous next token log probability computations.

This class provides an asynchronous interface for computing log probabilities using vLLM's engine. It is optimized for next token log probability computations and supports caching of results (outputs and KV).
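A minimal usage sketch (assuming vLLM is installed, a CUDA device is available, and that AsyncVirtualLM is importable from genlm.backend.llm; the model name is illustrative):

    from genlm.backend.llm import AsyncVirtualLM

    # Build an engine-backed LM with output caching for up to 1024 prompts.
    llm = AsyncVirtualLM.from_name("gpt2", cache_size=1024)

    # Synchronous query: normalized log-probabilities over the next token.
    prompt_ids = llm.tokenizer.encode("The capital of France is")
    logprobs = llm.next_token_logprobs_sync(prompt_ids)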

Source code in genlm/backend/llm/vllm.py
class AsyncVirtualLM(AsyncLM):
    """A wrapper around vLLM's `AsyncLLMEngine` for asynchronous next token log probability computations.

    This class provides an asynchronous interface for computing log probabilities using vLLM's engine.
    It is optimized for next token log probability computations and supports caching of results (outputs and KV).
    """

    default_params = SamplingParams(
        max_tokens=1, n=1, logprobs=1, detokenize=False, stop=None, ignore_eos=True
    )

    def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
        """Initialize an `AsyncVirtualLM` instance.

        Args:
            async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
            cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
            cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to {}.

        Note:
            The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
        """
        self.async_llm_engine = async_llm_engine
        self.tokenizer = async_llm_engine.engine.get_tokenizer()
        self.request_counter = Counter()
        self.custom_sampler = DeferredSampler()
        self.cache = (
            OutputCache(maxsize=cache_size, **cache_opts)
            if cache_size > 0
            else None
        )

        async_llm_engine.engine.log_stats = False

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, engine_opts=None, **kwargs):
        """Create a `AsyncVirtualLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load.
            engine_opts (dict): Additional options to pass to the `AsyncLLMEngine`. The engine will be
                configured with prefix caching enabled and async output processing disabled by default.
            **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

        Returns:
            (AsyncVirtualLM): An `AsyncVirtualLM` instance.
        """
        if not HAS_VLLM:
            raise ImportError(  # pragma: no cover
                "vLLM not available. Install vLLM or use AsyncTransformer instead."
            )

        if engine_opts is not None and "enable_chunked_prefill" in engine_opts:
            if engine_opts["enable_chunked_prefill"]:
                warnings.warn(
                    "Setting enable_chunked_prefill to True may interfere with AsyncVirtualLM's "
                    "custom sampling functionality."
                )

        engine_opts = {
            "enable_prefix_caching": True,
            "disable_log_requests": True,
            "disable_async_output_proc": True,
            # Need to disable chunked prefill to avoid issues
            # with our custom sampler.
            "enable_chunked_prefill": False,
            **(engine_opts or {}),
        }

        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
        )

        return cls(engine, **kwargs)

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously with output caching.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            result (torch.Tensor): Normalized log probability tensor.

        Warning:
            Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
            For synchronous usage, use the `next_token_logprobs_sync()` method instead.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        result = await self._next_token_logprobs(key)

        if self.cache is not None:
            self.cache[key] = result

        return result

    async def _next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        req_id = str(next(self.request_counter))
        prompt = TokensPrompt(prompt_token_ids=token_ids)

        outputs = []
        with self._optimized_sampling_context():
            async for output in self.async_llm_engine.generate(
                prompt=prompt,
                sampling_params=self.default_params,
                request_id=req_id,
            ):
                if output.finished:
                    outputs.append(output)

        return self._validate_outputs(outputs)

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self.batch_next_token_logprobs_sync([token_ids])[0]

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """
        Request log probabilities of next tokens in a batch synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

        Returns:
            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
        """
        req_ids = []
        for token_ids in token_ids_list:
            req_id = str(next(self.request_counter))
            req_ids.append(req_id)
            self.async_llm_engine.engine.add_request(
                prompt=TokensPrompt(prompt_token_ids=token_ids),
                params=self.default_params,
                request_id=req_id,
            )

        req_id2outputs = {}
        with self._optimized_sampling_context():
            while self.async_llm_engine.engine.has_unfinished_requests():
                output = self.async_llm_engine.engine.step()
                for out in output:
                    if out.finished:
                        assert out.request_id not in req_id2outputs, (
                            f"Duplicate outputs for request {out.request_id}"
                        )
                        assert out.request_id in req_ids, (
                            f"{out.request_id} not in requested IDs"
                        )
                        req_id2outputs[out.request_id] = out

        logprobs = [
            self._validate_outputs([req_id2outputs[req_id]]) for req_id in req_ids
        ]

        return torch.stack(logprobs)

    @contextmanager
    def _optimized_sampling_context(self):
        """Context manager for optimized sampling configuration."""
        model = self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model
        original_sampler = model.sampler
        try:
            model.sampler = self.custom_sampler
            yield
        finally:
            model.sampler = original_sampler

    def _validate_outputs(self, outputs):
        """Validate and extract logprobs from a vLLM output.

        Args:
            outputs: List of sequence group outputs from vLLM generation

        Returns:
            Tensor of log probabilities for the next token

        Raises:
            AssertionError: If output structure doesn't match expected format
        """
        assert len(outputs) == 1, "Expected exactly one sequence group"
        seq_group = outputs[0]

        assert len(seq_group.outputs) == 1, (
            "Expected exactly one sequence in output"
        )
        sequence = seq_group.outputs[0]

        assert len(sequence.logprobs) == 1, "Expected exactly one set of logprobs"
        token_logprobs = sequence.logprobs[0].logprobs

        return token_logprobs

    def clear_cache(self):
        """Clear output cache."""
        if self.cache:
            self.cache.clear()

    def __del__(self):
        """Clean up resources on deletion."""
        self._cleanup_engine()

    def _cleanup_engine(self):
        """Clean up the vLLM engine and associated resources."""
        if async_engine := getattr(self, "async_llm_engine", None):
            async_engine.shutdown_background_loop()
            destroy_model_parallel()
            destroy_distributed_environment()

__init__(async_llm_engine, cache_size=0, cache_opts={})

Initialize an AsyncVirtualLM instance.

Parameters:

    async_llm_engine (AsyncLLMEngine): The async vLLM engine instance. Required.
    cache_size (int): Maximum size of the output cache. If 0, caching is disabled. Default: 0.
    cache_opts (dict): Additional options to pass to the OutputCache constructor. Default: {}.

Note

The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.

Source code in genlm/backend/llm/vllm.py
def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
    """Initialize an `AsyncVirtualLM` instance.

    Args:
        async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
        cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
        cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to {}.

    Note:
        The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
    """
    self.async_llm_engine = async_llm_engine
    self.tokenizer = async_llm_engine.engine.get_tokenizer()
    self.request_counter = Counter()
    self.custom_sampler = DeferredSampler()
    self.cache = (
        OutputCache(maxsize=cache_size, **cache_opts)
        if cache_size > 0
        else None
    )

    async_llm_engine.engine.log_stats = False

    super().__init__(tokenizer=self.tokenizer)

from_name(model_name, engine_opts=None, **kwargs) classmethod

Create an AsyncVirtualLM instance from a model name.

Parameters:

    model_name (str): Name of the model to load. Required.
    engine_opts (dict): Additional options to pass to the AsyncLLMEngine. The engine will be configured with prefix caching enabled and async output processing disabled by default. Default: None.
    **kwargs: Additional arguments passed to the AsyncVirtualLM constructor.

Returns:

    AsyncVirtualLM: An AsyncVirtualLM instance.

Source code in genlm/backend/llm/vllm.py
@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
    """Create a `AsyncVirtualLM` instance from a model name.

    Args:
        model_name (str): Name of the model to load.
        engine_opts (dict): Additional options to pass to the `AsyncLLMEngine`. The engine will be
            configured with prefix caching enabled and async output processing disabled by default.
        **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

    Returns:
        (AsyncVirtualLM): An `AsyncVirtualLM` instance.
    """
    if not HAS_VLLM:
        raise ImportError(  # pragma: no cover
            "vLLM not available. Install vLLM or use AsyncTransformer instead."
        )

    if engine_opts is not None and "enable_chunked_prefill" in engine_opts:
        if engine_opts["enable_chunked_prefill"]:
            warnings.warn(
                "Setting enable_chunked_prefill to True may interfere with AsyncVirtualLM's "
                "custom sampling functionality."
            )

    engine_opts = {
        "enable_prefix_caching": True,
        "disable_log_requests": True,
        "disable_async_output_proc": True,
        # Need to disable chunked prefill to avoid issues
        # with our custom sampler.
        "enable_chunked_prefill": False,
        **(engine_opts or {}),
    }

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
    )

    return cls(engine, **kwargs)
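
For example, a hedged sketch of constructing an instance with custom engine options (gpu_memory_utilization and dtype are standard vLLM engine arguments; adjust or omit them for your setup, and the model name is illustrative):

    llm = AsyncVirtualLM.from_name(
        "meta-llama/Llama-3.2-1B-Instruct",
        engine_opts={"gpu_memory_utilization": 0.8, "dtype": "bfloat16"},
        cache_size=512,  # forwarded to the AsyncVirtualLM constructor
    )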

next_token_logprobs(token_ids) async

Request log probabilities of next token asynchronously with output caching.

Parameters:

    token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

    result (torch.Tensor): Normalized log probability tensor.

Warning

Do not use asyncio.run(next_token_logprobs()) as it may interfere with vLLM's background loop. For synchronous usage, use the next_token_logprobs_sync() method instead.

Source code in genlm/backend/llm/vllm.py
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token asynchronously with output caching.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        result (torch.Tensor): Normalized log probability tensor.

    Warning:
        Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
        For synchronous usage, use the `next_token_logprobs_sync()` method instead.
    """
    key = tuple(token_ids)

    if self.cache is not None and key in self.cache:
        return self.cache[key]

    result = await self._next_token_logprobs(key)

    if self.cache is not None:
        self.cache[key] = result

    return result
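
Because requests are asynchronous, several prompts can be scored concurrently from an existing event loop; repeated prompts are served from the output cache when cache_size > 0. A sketch (await this coroutine from your application's running loop rather than wrapping individual calls in asyncio.run):

    import asyncio

    async def score_prompts(llm, prompts):
        token_id_lists = [llm.tokenizer.encode(p) for p in prompts]
        # Requests are issued concurrently to the vLLM engine.
        return await asyncio.gather(
            *(llm.next_token_logprobs(ids) for ids in token_id_lists)
        )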

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Parameters:

    token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

    torch.Tensor: Normalized log probability tensor.

Source code in genlm/backend/llm/vllm.py
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    return self.batch_next_token_logprobs_sync([token_ids])[0]

batch_next_token_logprobs_sync(token_ids_list)

Request log probabilities of next tokens in a batch synchronously.

Parameters:

    token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model. Required.

Returns:

    torch.Tensor: A tensor of normalized log probability tensors, one for each prompt in the input list.

Source code in genlm/backend/llm/vllm.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """
    Request log probabilities of next tokens in a batch synchronously.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

    Returns:
        (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
    """
    req_ids = []
    for token_ids in token_ids_list:
        req_id = str(next(self.request_counter))
        req_ids.append(req_id)
        self.async_llm_engine.engine.add_request(
            prompt=TokensPrompt(prompt_token_ids=token_ids),
            params=self.default_params,
            request_id=req_id,
        )

    req_id2outputs = {}
    with self._optimized_sampling_context():
        while self.async_llm_engine.engine.has_unfinished_requests():
            output = self.async_llm_engine.engine.step()
            for out in output:
                if out.finished:
                    assert out.request_id not in req_id2outputs, (
                        f"Duplicate outputs for request {out.request_id}"
                    )
                    assert out.request_id in req_ids, (
                        f"{out.request_id} not in requested IDs"
                    )
                    req_id2outputs[out.request_id] = out

    logprobs = [
        self._validate_outputs([req_id2outputs[req_id]]) for req_id in req_ids
    ]

    return torch.stack(logprobs)
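
Continuing the sketch above, a small synchronous batch might look like this (prompts are illustrative):

    prompts = ["Paris is the capital of", "2 + 2 ="]
    batch_ids = [llm.tokenizer.encode(p) for p in prompts]
    batch_logprobs = llm.batch_next_token_logprobs_sync(batch_ids)
    # Row i of the resulting tensor holds next-token log-probabilities for prompts[i].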

clear_cache()

Clear output cache.

Source code in genlm/backend/llm/vllm.py
def clear_cache(self):
    """Clear output cache."""
    if self.cache:
        self.cache.clear()

__del__()

Clean up resources on deletion.

Source code in genlm/backend/llm/vllm.py
def __del__(self):
    """Clean up resources on deletion."""
    self._cleanup_engine()

AsyncTransformer

Bases: AsyncLM

Asynchronous wrapper around a HuggingFace causal language model with caching support.

This class provides an asynchronous interface to HuggingFace language models with automatic batching and caching (output and KV) for improved efficiency.

Source code in genlm/backend/llm/hf.py
class AsyncTransformer(AsyncLM):
    """Asynchronous wrapper around a HuggingFace causal language model with caching support.

    This class provides an asynchronous interface to HuggingFace language models with automatic batching
    and caching (output and KV) for improved efficiency.
    """

    @classmethod
    def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
        """Create an AsyncTransformer instance from a pretrained HuggingFace model.

        Args:
            model_id (str): Model identifier in HuggingFace's model hub.
            bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
                Defaults to None.
            hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
                Defaults to None.
            **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

        Returns:
            (AsyncTransformer): An initialized `AsyncTransformer` instance.
        """
        if bitsandbytes_opts:
            bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
        else:
            bnb_config = None

        _hf_opts = {
            "device_map": "auto",
            "torch_dtype": "auto",
        }
        if hf_opts:
            _hf_opts.update(hf_opts)

        tok = AutoTokenizer.from_pretrained(model_id)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, **_hf_opts
        )

        return cls(mod, tok, **kwargs)

    @torch.no_grad()
    def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
        """Initialize an AsyncTransformer instance.

        Args:
            hf_model: A HuggingFace CausalLM model instance.
            hf_tokenizer: A HuggingFace Tokenizer.
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
                Defaults to 20.
            timeout (float, optional): Seconds to wait since last query before processing current batch.
                Defaults to 0.02.
        """
        self.model = hf_model
        self.tokenizer = hf_tokenizer
        self.device = hf_model.device
        self.cache = TokenTrie()

        # Queries to be batched. Each query is a sequence of tokens,
        # and a Future to be called when the query is resolved.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        self.model.eval()

        super().__init__(tokenizer=self.tokenizer)

    def clear_cache(self):
        """Clear the cache of log probabilities and key/value pairs."""
        self.cache = TokenTrie()

    def clear_kv_cache(self):
        """Clear any key and value vectors from the cache."""
        self.cache.clear_kv_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
        to completion."""
        self.queries = []

    @torch.no_grad()
    def cache_kv(self, prompt_tokens):
        """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

        Args:
            prompt_tokens (list[int]): token ids for the prompt to cache.
        """
        result = self.model(torch.tensor([prompt_tokens]).to(self.device))
        node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
        node.past_key_values = result.past_key_values

    @torch.no_grad()
    def batch_evaluate_queries(self):
        """
        Process a batch of queued language model queries.

        This method is called internally when the `batch_size` has been met or the `timeout` has expired.
        """

        queries, self.queries = self.queries, []
        if len(queries) == 0:
            return

        query_groups = defaultdict(list)
        for query in queries:
            key = tuple(query.prompt)  # XXX: cache based on past_len too?
            query_groups[key].append(query)

        # Use one representative query from each group
        unique_queries = [group[0] for group in query_groups.values()]

        past_example = next((q.past for q in unique_queries if q.past), False)
        max_past_length = max(q.past_len for q in unique_queries)
        max_query_length = max(len(q.prompt) for q in unique_queries)

        padding_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else 0
        )

        input_ids = torch.tensor(
            [
                q.prompt_padded(padding_token_id, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        attn_masks = torch.tensor(
            [
                q.attention_mask(max_past_length, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        posn_ids = torch.tensor(
            [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
        ).to(self.device)
        if past_example:
            pasts = [
                [
                    torch.cat(
                        (
                            *(
                                q.past_padded(
                                    layer,
                                    j,
                                    max_past_length,
                                    past_example[0][0].dtype,
                                    self.device,
                                    past_example[0][0].shape,
                                )
                                for q in unique_queries
                            ),
                        ),
                        dim=0,
                    )
                    for j in range(2)
                ]
                for layer in range(len(past_example))
            ]
        else:
            pasts = None

        results = self.model(
            input_ids,
            attention_mask=attn_masks,
            position_ids=posn_ids,
            past_key_values=pasts,
            use_cache=pasts is not None,
        )

        assert len(results.logits) == len(unique_queries)

        for i, q in enumerate(unique_queries):
            result = results.logits[i]
            for dup_query in query_groups[tuple(q.prompt)]:
                dup_query.future.set_result(result)

    @torch.no_grad()
    def add_query(self, query, future, past):
        """Add a query to be evaluated in the next batch.

        This method is called internally when a `next_token_logprobs` request is made.

        Args:
            query (list[int]): Token IDs representing the query prompt
            future (asyncio.Future): Future to store the result in
            past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
                or None if this is a new query
        """
        self.queries.append(Query(query, future, past))

        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple:
                - CacheNode: The deepest node in the cache tree that matches the token sequence
                - int: Number of tokens matched from the start of token_ids
                - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                    or None if no cached states were found
                - int: Base index indicating where the past states start in token_ids
        """
        # Walk while tokens can be found
        node = self.cache
        next_token_index = 0

        past = None
        base = 0
        while next_token_index < len(token_ids):
            if node.past_key_values is not None:
                past = node.past_key_values
                base = next_token_index
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        return node, next_token_index, past, base

    @torch.no_grad()
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, base = self.walk_cache(token_ids)

        # If we processed all tokens, then we're done.
        if next_token_index == len(token_ids):
            return node.logprobs

        # Create a future with the prompt
        future = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], future, past)
        logits = await future

        # Create new nodes
        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    @torch.no_grad()
    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        # Walk while tokens can be found
        node, next_token_index, past, base = self.walk_cache(token_ids)

        if next_token_index == len(token_ids):
            return node.logprobs

        logits = self.model(
            torch.tensor([token_ids[base:]]).to(self.device),
            past_key_values=node.past_key_values,
            use_cache=node.past_key_values is not None,
        ).logits[0]

        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    def next_token_logprobs_uncached(self, token_ids):
        """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        with torch.no_grad():
            logits = self.model(
                torch.tensor([token_ids]).to(self.device),
                past_key_values=None,
                use_cache=False,
            ).logits[0]
            return torch.log_softmax(logits[-1], dim=0)

from_name(model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs) classmethod

Create an AsyncTransformer instance from a pretrained HuggingFace model.

Parameters:

    model_id (str): Model identifier in HuggingFace's model hub. Required.
    bitsandbytes_opts (dict): Additional configuration options for bitsandbytes quantization. Default: None.
    hf_opts (dict): Additional configuration options for loading the HuggingFace model. Default: None.
    **kwargs: Additional arguments passed to the AsyncTransformer constructor.

Returns:

    AsyncTransformer: An initialized AsyncTransformer instance.

Source code in genlm/backend/llm/hf.py
@classmethod
def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
    """Create an AsyncTransformer instance from a pretrained HuggingFace model.

    Args:
        model_id (str): Model identifier in HuggingFace's model hub.
        bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
            Defaults to None.
        hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
            Defaults to None.
        **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

    Returns:
        (AsyncTransformer): An initialized `AsyncTransformer` instance.
    """
    if bitsandbytes_opts:
        bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
    else:
        bnb_config = None

    _hf_opts = {
        "device_map": "auto",
        "torch_dtype": "auto",
    }
    if hf_opts:
        _hf_opts.update(hf_opts)

    tok = AutoTokenizer.from_pretrained(model_id)
    mod = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, **_hf_opts
    )

    return cls(mod, tok, **kwargs)
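
A hedged sketch of loading a model with half-precision weights and a larger auto-batching window (the model name and options are illustrative):

    import torch
    from genlm.backend.llm import AsyncTransformer

    mt = AsyncTransformer.from_name(
        "gpt2",
        hf_opts={"torch_dtype": torch.float16},  # forwarded to from_pretrained
        batch_size=32,   # max queries per auto-batch
        timeout=0.05,    # seconds of inactivity before flushing a batch
    )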

__init__(hf_model, hf_tokenizer, batch_size=20, timeout=0.02)

Initialize an AsyncTransformer instance.

Parameters:

    hf_model: A HuggingFace CausalLM model instance. Required.
    hf_tokenizer: A HuggingFace Tokenizer. Required.
    batch_size (int): Maximum queries to process in one batch during auto-batching. Default: 20.
    timeout (float): Seconds to wait since last query before processing current batch. Default: 0.02.
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
    """Initialize an AsyncTransformer instance.

    Args:
        hf_model: A HuggingFace CausalLM model instance.
        hf_tokenizer: A HuggingFace Tokenizer.
        batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
            Defaults to 20.
        timeout (float, optional): Seconds to wait since last query before processing current batch.
            Defaults to 0.02.
    """
    self.model = hf_model
    self.tokenizer = hf_tokenizer
    self.device = hf_model.device
    self.cache = TokenTrie()

    # Queries to be batched. Each query is a sequence of tokens,
    # and a Future to be called when the query is resolved.
    self.queries = []
    self.batch_size = batch_size
    self.timeout = timeout
    self.timer = None

    self.model.eval()

    super().__init__(tokenizer=self.tokenizer)

clear_cache()

Clear the cache of log probabilities and key/value pairs.

Source code in genlm/backend/llm/hf.py
def clear_cache(self):
    """Clear the cache of log probabilities and key/value pairs."""
    self.cache = TokenTrie()

clear_kv_cache()

Clear any key and value vectors from the cache.

Source code in genlm/backend/llm/hf.py
def clear_kv_cache(self):
    """Clear any key and value vectors from the cache."""
    self.cache.clear_kv_cache()

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/hf.py
def reset_async_queries(self):
    """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
    to completion."""
    self.queries = []

cache_kv(prompt_tokens)

Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

Parameters:

    prompt_tokens (list[int]): Token IDs for the prompt to cache. Required.
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def cache_kv(self, prompt_tokens):
    """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

    Args:
        prompt_tokens (list[int]): token ids for the prompt to cache.
    """
    result = self.model(torch.tensor([prompt_tokens]).to(self.device))
    node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
    node.past_key_values = result.past_key_values
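
A sketch of caching a shared prompt prefix so that later queries only evaluate the new tokens (tokenization of the prefix and continuation is simplified for illustration):

    prefix_ids = mt.tokenizer.encode("Translate English to French: ")
    mt.cache_kv(prefix_ids)

    # This query reuses the cached key/value states for the prefix and only
    # runs the model on the continuation tokens.
    continuation_ids = mt.tokenizer.encode("cheese", add_special_tokens=False)
    logprobs = mt.next_token_logprobs_sync(prefix_ids + continuation_ids)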

batch_evaluate_queries()

Process a batch of queued language model queries.

This method is called internally when the batch_size has been met or the timeout has expired.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def batch_evaluate_queries(self):
    """
    Process a batch of queued language model queries.

    This method is called internally when the `batch_size` has been met or the `timeout` has expired.
    """

    queries, self.queries = self.queries, []
    if len(queries) == 0:
        return

    query_groups = defaultdict(list)
    for query in queries:
        key = tuple(query.prompt)  # XXX: cache based on past_len too?
        query_groups[key].append(query)

    # Use one representative query from each group
    unique_queries = [group[0] for group in query_groups.values()]

    past_example = next((q.past for q in unique_queries if q.past), False)
    max_past_length = max(q.past_len for q in unique_queries)
    max_query_length = max(len(q.prompt) for q in unique_queries)

    padding_token_id = (
        self.tokenizer.pad_token_id
        if self.tokenizer.pad_token_id is not None
        else 0
    )

    input_ids = torch.tensor(
        [
            q.prompt_padded(padding_token_id, max_query_length)
            for q in unique_queries
        ]
    ).to(self.device)
    attn_masks = torch.tensor(
        [
            q.attention_mask(max_past_length, max_query_length)
            for q in unique_queries
        ]
    ).to(self.device)
    posn_ids = torch.tensor(
        [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
    ).to(self.device)
    if past_example:
        pasts = [
            [
                torch.cat(
                    (
                        *(
                            q.past_padded(
                                layer,
                                j,
                                max_past_length,
                                past_example[0][0].dtype,
                                self.device,
                                past_example[0][0].shape,
                            )
                            for q in unique_queries
                        ),
                    ),
                    dim=0,
                )
                for j in range(2)
            ]
            for layer in range(len(past_example))
        ]
    else:
        pasts = None

    results = self.model(
        input_ids,
        attention_mask=attn_masks,
        position_ids=posn_ids,
        past_key_values=pasts,
        use_cache=pasts is not None,
    )

    assert len(results.logits) == len(unique_queries)

    for i, q in enumerate(unique_queries):
        result = results.logits[i]
        for dup_query in query_groups[tuple(q.prompt)]:
            dup_query.future.set_result(result)

add_query(query, future, past)

Add a query to be evaluated in the next batch.

This method is called internally when a next_token_logprobs request is made.

Parameters:

    query (list[int]): Token IDs representing the query prompt. Required.
    future (asyncio.Future): Future to store the result in. Required.
    past (list[tuple[torch.Tensor]] | None): Past key/value states from previous evaluation, or None if this is a new query. Required.
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def add_query(self, query, future, past):
    """Add a query to be evaluated in the next batch.

    This method is called internally when a `next_token_logprobs` request is made.

    Args:
        query (list[int]): Token IDs representing the query prompt
        future (asyncio.Future): Future to store the result in
        past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
            or None if this is a new query
    """
    self.queries.append(Query(query, future, past))

    if self.timer:
        self.timer.cancel()
        self.timer = None
    if len(self.queries) >= self.batch_size:
        self.batch_evaluate_queries()
    else:
        self.timer = asyncio.get_running_loop().call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

    token_ids (list[int]): Sequence of token IDs to follow in the cache tree. Required.

Returns:

    tuple:
      • CacheNode: The deepest node in the cache tree that matches the token sequence
      • int: Number of tokens matched from the start of token_ids
      • list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node, or None if no cached states were found
      • int: Base index indicating where the past states start in token_ids
Source code in genlm/backend/llm/hf.py
def walk_cache(self, token_ids):
    """Walk the cache tree to find the deepest node matching a sequence of tokens.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple:
            - CacheNode: The deepest node in the cache tree that matches the token sequence
            - int: Number of tokens matched from the start of token_ids
            - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                or None if no cached states were found
            - int: Base index indicating where the past states start in token_ids
    """
    # Walk while tokens can be found
    node = self.cache
    next_token_index = 0

    past = None
    base = 0
    while next_token_index < len(token_ids):
        if node.past_key_values is not None:
            past = node.past_key_values
            base = next_token_index
        if node.has_token(token_ids[next_token_index]):
            node = node.get_token(token_ids[next_token_index])
            next_token_index += 1
        else:
            break

    return node, next_token_index, past, base

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

    token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

    logprobs (torch.Tensor): A tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    node, next_token_index, past, base = self.walk_cache(token_ids)

    # If we processed all tokens, then we're done.
    if next_token_index == len(token_ids):
        return node.logprobs

    # Create a future with the prompt
    future = asyncio.get_running_loop().create_future()
    self.add_query(token_ids[base:], future, past)
    logits = await future

    # Create new nodes
    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs
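
Concurrent calls are grouped automatically: up to batch_size queries issued within timeout seconds of each other are evaluated in a single forward pass. A sketch using asyncio.gather:

    import asyncio

    async def batched_next_token_logprobs(mt, token_id_lists):
        # Each coroutine enqueues a query; the batcher resolves all futures
        # once the batch is evaluated.
        return await asyncio.gather(
            *(mt.next_token_logprobs(ids) for ids in token_id_lists)
        )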

next_token_logprobs_sync(token_ids)

Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

Parameters:

    token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

    logprobs (torch.Tensor): A tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    # Walk while tokens can be found
    node, next_token_index, past, base = self.walk_cache(token_ids)

    if next_token_index == len(token_ids):
        return node.logprobs

    logits = self.model(
        torch.tensor([token_ids[base:]]).to(self.device),
        past_key_values=node.past_key_values,
        use_cache=node.past_key_values is not None,
    ).logits[0]

    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_uncached(token_ids)

Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

Parameters:

    token_ids (list[int]): A list of token IDs, representing a prompt to the language model. Required.

Returns:

    logprobs (torch.Tensor): A tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
def next_token_logprobs_uncached(self, token_ids):
    """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    with torch.no_grad():
        logits = self.model(
            torch.tensor([token_ids]).to(self.device),
            past_key_values=None,
            use_cache=False,
        ).logits[0]
        return torch.log_softmax(logits[-1], dim=0)

load_model_by_name(name, backend=None, llm_opts=None)

Load a language model by name.

Parameters:

    name (str): Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct"). Required.
    backend (str): Backend to use for inference. Can be "vllm" or "hf". If None, defaults to "vllm" if CUDA is available, otherwise "hf". Default: None.
    llm_opts (dict): Additional options to pass to the backend constructor. See the AsyncVirtualLM and AsyncTransformer documentation for details. Default: None.

Returns:

    AsyncLM: An asynchronous language model.

Raises:

    ValueError: If an invalid backend is specified.

Source code in genlm/backend/llm/__init__.py
def load_model_by_name(name, backend=None, llm_opts=None):
    """Load a language model by name.

    Args:
        name (str): Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct")
        backend (str, optional): Backend to use for inference. Can be "vllm" or "hf".
            If None, defaults to "vllm" if CUDA is available, otherwise "hf".
        llm_opts (dict, optional): Additional options to pass to the backend constructor.
            See AsyncVirtualLM and AsyncTransformer documentation for details.

    Returns:
        (AsyncLM): An asynchronous language model.

    Raises:
        (ValueError): If an invalid backend is specified.
    """
    if backend is None:
        backend = "vllm" if torch.cuda.is_available() else "hf"

    if llm_opts is None:
        llm_opts = {}

    if backend == "vllm":
        return AsyncVirtualLM.from_name(name, **llm_opts)
    elif backend == "hf":
        return AsyncTransformer.from_name(name, **llm_opts)
    else:
        raise ValueError(f"Invalid backend: {backend}")

decode_vocab(tokenizer, byte2str_fallback='tokenizer')

Convert tokenizer vocabulary into byte and string representations.

Warning

The byte representation is the canonical form. The string representation is provided for convenience but may not decode properly for all tokens, especially those containing invalid UTF-8 sequences.

Parameters:

    tokenizer: A Hugging Face tokenizer instance. Required.
    byte2str_fallback (str): Strategy for converting invalid UTF-8 bytes to strings. Default: 'tokenizer'. Options:
      • 'tokenizer': Use the tokenizer's convert_ids_to_tokens (default)
      • 'latin1': Decode using latin1 encoding
      • 'replace': Use the Unicode replacement character '�'

Returns:

    tuple: (byte_vocab, str_vocab)

Source code in genlm/backend/tokenization/vocab.py
def decode_vocab(tokenizer, byte2str_fallback="tokenizer"):
    """Convert tokenizer vocabulary into byte and string representations.

    Warning:
        The byte representation is the canonical form. The string representation is provided for
        convenience but may not decode properly for all tokens, especially those containing invalid UTF-8 sequences.

    Args:
        tokenizer: A Hugging Face tokenizer instance
        byte2str_fallback (str): Strategy for converting invalid UTF-8 bytes to strings. Options:\n
            - 'tokenizer': Use tokenizer's `convert_ids_to_tokens` (default)
            - 'latin1': Decode using latin1 encoding
            - 'replace': Use Unicode replacement character '�'

    Returns:
        (tuple): (byte_vocab, str_vocab)
    """
    if byte2str_fallback not in ["latin1", "tokenizer", "replace"]:
        raise ValueError(f"Unknown byte2str_fallback strategy: {byte2str_fallback}")

    if tokenizer.is_fast:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer.name_or_path, use_fast=False
        )

    # Try slow tokenizer.
    try:
        byte_vocab = get_byte_vocab(tokenizer)
    except ByteVocabError:
        # warnings.warn("Could not decode vocabulary from slow tokenizer. Trying using fast tokenizer.")

        # Try fast tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer.name_or_path, use_fast=True)
        try:
            byte_vocab = get_byte_vocab(tokenizer)
        except ByteVocabError as e:
            raise ValueError(
                f"Could not decode byte representation of token vocabuary from tokenizer {tokenizer.name_or_path}"
            ) from e

    str_vocab = bytes_to_strs(tokenizer, byte_vocab, byte2str_fallback)

    return byte_vocab, str_vocab
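
A sketch of decoding a tokenizer's vocabulary (assuming decode_vocab is exported from genlm.backend.tokenization; the model name is illustrative):

    from transformers import AutoTokenizer
    from genlm.backend.tokenization import decode_vocab

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    byte_vocab, str_vocab = decode_vocab(tokenizer)

    token_id = tokenizer.encode("hello")[0]
    print(byte_vocab[token_id])  # canonical byte representation of the token
    print(str_vocab[token_id])   # best-effort string representation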

TokenCharacterTrie

A trie data structure for efficient token-to-character mapping.

Source code in genlm/backend/trie/base.py
class TokenCharacterTrie:
    """A trie data structure for efficient token-to-character mapping."""

    def __init__(self, decode):
        """Initialize a `TokenCharacterTrie`.

        Args:
            decode (list): List representing the token vocabulary.
                Each element of the list must be iterable.
        """
        self.decode = decode
        self.word2leaf = {}
        self.children = [{}]  # First node is root
        self.root = 0
        self.token_id_to_leaf = []

        for token_id, word in enumerate(self.decode):
            curr = self.root
            for letter in word:
                if letter not in self.children[curr]:
                    self.children[curr][letter] = len(self.children)
                    self.children.append({})
                curr = self.children[curr][letter]

            self.children[curr][None] = last = len(self.children)
            self.children.append({})
            assert word not in self.word2leaf, (
                "Can't have duplicate words in vocabulary"
            )
            self.word2leaf[word] = last

            self.token_id_to_leaf.append((token_id, last))

        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
        )
        self.ordering = np.array(list(self._order(self.root)), np.int32)

        # Renumber the states of the trie so that they are named by a contiguous
        # range of integers, and those integers respect the topological ordering
        # of the trie.  This improves the efficiency of updating the trie, as it
        # improves memory locality.
        ordering = {}
        for i, x in enumerate(self._order_full(self.root)):
            ordering[x] = i
        self._rename(f=lambda x: ordering[x])

        node2prefix = {self.root: []}
        for x in reversed(range(len(self.children))):
            for letter, y in self.children[x].items():
                if letter is None:
                    node2prefix[y] = node2prefix[x]
                else:
                    node2prefix[y] = node2prefix[x] + [letter]
        self.node2prefix = node2prefix

    def _rename(self, f):
        """Rename all node indices in the trie using the provided mapping function.

        Args:
            f (callable): Function that maps old node indices to new node indices
        """
        N = len(self.children)

        new_children = [{} for _ in range(N)]
        nodes = range(N)

        for x in nodes:
            for letter, y in self.children[x].items():
                new_children[f(x)][letter] = f(y)

        self.root = f(self.root)
        self.children = new_children
        self.word2leaf = {w: f(x) for w, x in self.word2leaf.items()}
        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))

        self.token_id_to_leaf = np.array(
            [(i, f(x)) for i, x in self.token_id_to_leaf], dtype=np.int32
        )

        self.ordering = np.array([f(x) for x in self.ordering])
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in new_children]
        )

    def _alloc_weights(self):
        """Allocate an array to store weight values for all nodes.

        Returns:
            np.ndarray: Zero-initialized array for storing weight values
        """
        return np.zeros(len(self.children), dtype=np.float64)

    def _preprocess_ws(self, ws):
        """Preprocess the weight vector to ensure it is a numpy array and on the correct device.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Weight vector
        """
        if isinstance(ws, torch.Tensor):
            if ws.device.type != "cpu":
                ws = ws.cpu()
            ws = ws.numpy()
        return ws

    def weight_sum(self, ws):
        """Compute weight sum for each node in the trie.

        For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
        that are descendants of that node.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Summed weights for each node in the trie.
        """
        ws = self._preprocess_ws(ws)
        node_ws = self._alloc_weights()
        _update_trie_numba_sum(
            node_ws=node_ws,
            ws=ws,
            token_id_to_leaf=self.token_id_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return node_ws

    def weight_max(self, ws):
        """Compute weight max for each node in the trie.

        For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
        that are descendants of that node.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Weight max values for each node in the trie.
        """
        ws = self._preprocess_ws(ws)
        node_ws = self._alloc_weights()
        _update_trie_numba_max(
            node_ws=node_ws,
            ws=ws,
            token_id_to_leaf=self.token_id_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return node_ws

    def batch_weight_sum(self, ws):
        """Batched equivalent of `weight_sum`.

        Args:
            ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Batch of weight values of `len(ws)` for each node in the trie
        """
        return np.array([self.weight_sum(ws) for ws in ws])

    def batch_weight_max(self, ws):
        """Batched equivalent of `weight_max`.

        Args:
            ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Batch of weight max values of `len(ws)` for each node in the trie
        """
        return np.array([self.weight_max(ws) for ws in ws])

    def _order(self, node):
        """Generate a topological ordering of nodes beneath the given node.

        Args:
            node (int): Starting node index

        Yields:
            int: Node indices in topological order
        """
        for a in self.children[node]:
            if a is None:
                pass
            else:
                yield from self._order(self.children[node][a])
        yield node

    def _order_full(self, node):
        """Generate a complete topological ordering including all child nodes.

        Args:
            node (int): Starting node index

        Yields:
            (int): Node indices in complete topological order
        """
        for a in self.children[node]:
            yield from self._order_full(self.children[node][a])
        yield node

    def visualize(self, ws=None):
        """Visualize the trie structure using Graphviz.

        Args:
            ws (np.ndarray|None): Optional weight vector to display at each node.
                                Should be of length `len(self.children)`.

        Returns:
            (graphviz.Digraph): The generated graph object
        """
        try:
            import graphviz
        except ImportError:  # pragma: no cover
            raise ImportError(
                "Please install graphviz: pip install graphviz"
            )  # pragma: no cover

        if ws is not None and len(ws) != len(self.children):
            raise ValueError(
                f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
            )

        dot = graphviz.Digraph(comment="Token Character Trie")
        dot.attr(rankdir="LR")

        # Create a subgraph for the legend
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", fontsize="10")
            legend.attr("node", fontsize="7", width="0.1", height="0.1")

            # Example internal node
            legend.node(
                "legend_internal",
                "Internal Node ID\n'Prefix'\nWeight (if provided)",
                shape="circle",
            )

            # Example leaf node
            legend.node("legend_leaf", "Complete Token", shape="doublecircle")

            legend.edge(
                "legend_internal",
                "legend_leaf",
                label="Token item",
                fontsize="10",
            )

            # Align legend horizontally
            legend.attr(rankdir="TB")
            legend.attr(rank="same")

        # Add the main trie nodes and edges
        for node_id in range(len(self.children)):
            prefix = self.node2prefix[node_id]

            if ws is not None:
                label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
            else:
                label = f"{node_id}\n'{prefix}'"

            # Color nodes based on mass if provided
            if ws is not None:
                max_ws = ws.max()
                if max_ws > 0:
                    intensity = int(255 * (1 - ws[node_id] / max_ws))
                    color = f"#{intensity:02x}{255:02x}{intensity:02x}"
                else:
                    color = "#ffffff"  # white for zero mass
            else:
                color = "#ffffff"  # default white

            if node_id in self.leaf2word:
                dot.node(
                    str(node_id),
                    label,
                    shape="doublecircle",
                    style="filled",
                    fillcolor=color,
                )
            else:
                dot.node(
                    str(node_id), label, shape="circle", style="filled", fillcolor=color
                )

        for node_id, children in enumerate(self.children):
            for char, child_id in children.items():
                if char is not None:
                    edge_label = str(char)
                else:
                    edge_label = "End-of-Token"

                dot.edge(str(node_id), str(child_id), label=edge_label)

        return dot

__init__(decode)

Initialize a TokenCharacterTrie.

Parameters:

Name Type Description Default
decode list

List representing the token vocabulary. Each element of the list must be iterable.

required
Source code in genlm/backend/trie/base.py
def __init__(self, decode):
    """Initialize a `TokenCharacterTrie`.

    Args:
        decode (list): List representing the token vocabulary.
            Each element of the list must be iterable.
    """
    self.decode = decode
    self.word2leaf = {}
    self.children = [{}]  # First node is root
    self.root = 0
    self.token_id_to_leaf = []

    for token_id, word in enumerate(self.decode):
        curr = self.root
        for letter in word:
            if letter not in self.children[curr]:
                self.children[curr][letter] = len(self.children)
                self.children.append({})
            curr = self.children[curr][letter]

        self.children[curr][None] = last = len(self.children)
        self.children.append({})
        assert word not in self.word2leaf, (
            "Can't have duplicate words in vocabulary"
        )
        self.word2leaf[word] = last

        self.token_id_to_leaf.append((token_id, last))

    self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
    self.jump = List(
        [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
    )
    self.ordering = np.array(list(self._order(self.root)), np.int32)

    # Renumber the states of the trie so that they are named by a contiguous
    # range of integers that respects the topological ordering of the trie.
    # This improves the efficiency of updating the trie by improving memory
    # locality.
    ordering = {}
    for i, x in enumerate(self._order_full(self.root)):
        ordering[x] = i
    self._rename(f=lambda x: ordering[x])

    node2prefix = {self.root: []}
    for x in reversed(range(len(self.children))):
        for letter, y in self.children[x].items():
            if letter is None:
                node2prefix[y] = node2prefix[x]
            else:
                node2prefix[y] = node2prefix[x] + [letter]
    self.node2prefix = node2prefix
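
A minimal construction sketch (not part of the library source; the toy vocabulary is illustrative, and the import path assumes the class lives in genlm/backend/trie/base.py as stated above):

from genlm.backend.trie.base import TokenCharacterTrie

# Each token must be an iterable of "characters" (here: plain strings).
trie = TokenCharacterTrie(decode=["a", "ab", "b"])

print(trie.word2leaf)               # one leaf node per token
print(trie.node2prefix[trie.root])  # [] -- the root corresponds to the empty prefix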

weight_sum(ws)

Compute weight sum for each node in the trie.

For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node.

Parameters:

Name Type Description Default
ws Tensor | ndarray

Token weights over the vocabulary of shape (len(self.decode),)

required

Returns:

Type Description
ndarray

Summed weights for each node in the trie.

Source code in genlm/backend/trie/base.py
def weight_sum(self, ws):
    """Compute weight sum for each node in the trie.

    For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
    that are descendants of that node.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Summed weights for each node in the trie.
    """
    ws = self._preprocess_ws(ws)
    node_ws = self._alloc_weights()
    _update_trie_numba_sum(
        node_ws=node_ws,
        ws=ws,
        token_id_to_leaf=self.token_id_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return node_ws
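
A small worked example (toy vocabulary and weights, purely illustrative): every leaf carries exactly its own token's weight, and the root, whose descendants include every leaf, accumulates the total.

import numpy as np
from genlm.backend.trie.base import TokenCharacterTrie

trie = TokenCharacterTrie(decode=["a", "ab", "b"])
ws = np.array([0.2, 0.3, 0.5])  # one weight per token in decode

node_ws = trie.weight_sum(ws)

assert np.isclose(node_ws[trie.root], ws.sum())        # root sums the whole vocabulary
assert np.isclose(node_ws[trie.word2leaf["ab"]], 0.3)  # a leaf holds its token's weight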

weight_max(ws)

Compute weight max for each node in the trie.

For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node.

Parameters:

Name Type Description Default
ws Tensor | ndarray

Token weights over the vocabulary of shape (len(self.decode),)

required

Returns:

Type Description
ndarray

Weight max values for each node in the trie.

Source code in genlm/backend/trie/base.py
def weight_max(self, ws):
    """Compute weight max for each node in the trie.

    For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
    that are descendants of that node.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Weight max values for each node in the trie.
    """
    ws = self._preprocess_ws(ws)
    node_ws = self._alloc_weights()
    _update_trie_numba_max(
        node_ws=node_ws,
        ws=ws,
        token_id_to_leaf=self.token_id_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return node_ws
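
The analogous sketch for weight_max (same illustrative setup): each node reports the largest weight among the tokens below it, so the root reports the maximum over the whole vocabulary.

import numpy as np
from genlm.backend.trie.base import TokenCharacterTrie

trie = TokenCharacterTrie(decode=["a", "ab", "b"])
node_ws = trie.weight_max(np.array([0.2, 0.3, 0.5]))

assert np.isclose(node_ws[trie.root], 0.5)  # the maximum over all tokens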

batch_weight_sum(ws)

Batched equivalent of weight_sum.

Parameters:

Name Type Description Default
ws list[Tensor | ndarray]

Batch of token weights, each of shape (len(self.decode),)

required

Returns:

Type Description
ndarray

Batch of summed weights of shape (len(ws), num_nodes), one row of node weights per input weight vector

Source code in genlm/backend/trie/base.py
def batch_weight_sum(self, ws):
    """Batched equivalent of `weight_sum`.

    Args:
        ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Batch of summed weights of shape `(len(ws), num_nodes)`, one row of node weights per input weight vector
    """
    return np.array([self.weight_sum(ws) for ws in ws])

batch_weight_max(ws)

Batched equivalent of weight_max.

Parameters:

Name Type Description Default
ws list[Tensor | ndarray]

Batch of token weights, each of shape (len(self.decode),)

required

Returns:

Type Description
ndarray

Batch of maximum weights of shape (len(ws), num_nodes), one row of node weights per input weight vector

Source code in genlm/backend/trie/base.py
def batch_weight_max(self, ws):
    """Batched equivalent of `weight_max`.

    Args:
        ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Batch of maximum weights of shape `(len(ws), num_nodes)`, one row of node weights per input weight vector
    """
    return np.array([self.weight_max(ws) for ws in ws])
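
Both batched variants simply loop over the sequential implementation and stack the results; a short sketch with illustrative weights:

import numpy as np
from genlm.backend.trie.base import TokenCharacterTrie

trie = TokenCharacterTrie(decode=["a", "ab", "b"])
batch = [np.array([0.2, 0.3, 0.5]), np.array([1.0, 0.0, 2.0])]

sums = trie.batch_weight_sum(batch)  # shape: (len(batch), num_nodes)
maxs = trie.batch_weight_max(batch)  # shape: (len(batch), num_nodes)

assert sums.shape == maxs.shape == (2, len(trie.children))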

visualize(ws=None)

Visualize the trie structure using Graphviz.

Parameters:

Name Type Description Default
ws ndarray | None

Optional weight vector to display at each node. Should be of length len(self.children).

None

Returns:

Type Description
Digraph

The generated graph object

Source code in genlm/backend/trie/base.py
def visualize(self, ws=None):
    """Visualize the trie structure using Graphviz.

    Args:
        ws (np.ndarray|None): Optional weight vector to display at each node.
                            Should be of length `len(self.children)`.

    Returns:
        (graphviz.Digraph): The generated graph object
    """
    try:
        import graphviz
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Please install graphviz: pip install graphviz"
        )  # pragma: no cover

    if ws is not None and len(ws) != len(self.children):
        raise ValueError(
            f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
        )

    dot = graphviz.Digraph(comment="Token Character Trie")
    dot.attr(rankdir="LR")

    # Create a subgraph for the legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", fontsize="10")
        legend.attr("node", fontsize="7", width="0.1", height="0.1")

        # Example internal node
        legend.node(
            "legend_internal",
            "Internal Node ID\n'Prefix'\nWeight (if provided)",
            shape="circle",
        )

        # Example leaf node
        legend.node("legend_leaf", "Complete Token", shape="doublecircle")

        legend.edge(
            "legend_internal",
            "legend_leaf",
            label="Token item",
            fontsize="10",
        )

        # Align legend horizontally
        legend.attr(rankdir="TB")
        legend.attr(rank="same")

    # Add the main trie nodes and edges
    for node_id in range(len(self.children)):
        prefix = self.node2prefix[node_id]

        if ws is not None:
            label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
        else:
            label = f"{node_id}\n'{prefix}'"

        # Color nodes based on mass if provided
        if ws is not None:
            max_ws = ws.max()
            if max_ws > 0:
                intensity = int(255 * (1 - ws[node_id] / max_ws))
                color = f"#{intensity:02x}{255:02x}{intensity:02x}"
            else:
                color = "#ffffff"  # white for zero mass
        else:
            color = "#ffffff"  # default white

        if node_id in self.leaf2word:
            dot.node(
                str(node_id),
                label,
                shape="doublecircle",
                style="filled",
                fillcolor=color,
            )
        else:
            dot.node(
                str(node_id), label, shape="circle", style="filled", fillcolor=color
            )

    for node_id, children in enumerate(self.children):
        for char, child_id in children.items():
            if char is not None:
                edge_label = str(char)
            else:
                edge_label = "End-of-Token"

            dot.edge(str(node_id), str(child_id), label=edge_label)

    return dot
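
One possible way to render the trie (requires the optional graphviz dependency; the output file name is arbitrary):

import numpy as np
from genlm.backend.trie.base import TokenCharacterTrie

trie = TokenCharacterTrie(decode=["a", "ab", "b"])
node_ws = trie.weight_sum(np.array([0.2, 0.3, 0.5]))

dot = trie.visualize(ws=node_ws)  # passing node weights is optional
dot.render("token_trie", format="svg", cleanup=True)  # writes token_trie.svg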

ParallelTokenCharacterTrie

Bases: TokenCharacterTrie

A GPU-optimized version of TokenCharacterTrie that performs weight sum and max operations in parallel.

Source code in genlm/backend/trie/parallel.py
class ParallelTokenCharacterTrie(TokenCharacterTrie):
    """A GPU-optimized version of `TokenCharacterTrie` that performs weight sum and max operations in parallel."""

    def __init__(self, decode, device=None, **kwargs):
        super().__init__(decode, **kwargs)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device not in ["cpu", "cuda"]:
            raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

        self._build_reachability_matrix()
        self.token_ids = torch.tensor(
            self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
        )

    def _build_parent_map(self):
        """Builds a mapping from each node to its parent node in the trie.

        Returns:
            (dict): A dictionary where keys are child nodes and values are their parent nodes.
        """
        parent = {}
        for node in range(len(self.children)):
            for child in self.jump[node]:
                parent[child] = node
        return parent

    def _build_reachability_matrix(self):
        """Constructs a sparse reachability matrix for efficient weight propagation.

        The matrix M is constructed such that M[i,j] = 1 if node j is either:
        - The leaf node i itself (self-connection)
        - An ancestor of leaf node i in the trie
        """
        leaf_indices = self.token_id_to_leaf[:, 1]
        parent = self._build_parent_map()

        rows, cols = [], []
        for i, node in enumerate(leaf_indices):
            # self connections
            rows.append(i)
            cols.append(node)

            current = node
            while current in parent:  # Walk up to root
                ancestor = parent[current]
                rows.append(i)
                cols.append(ancestor)
                current = ancestor

        self.src_indices = torch.tensor(rows, dtype=torch.long, device=self.device)
        self.dst_indices = torch.tensor(cols, dtype=torch.long, device=self.device)

        indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device)
        values = torch.ones(len(rows), device=self.device)

        self.M = torch.sparse_coo_tensor(
            indices, values, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

    def _preprocess_ws(self, batch_ws):
        processed_batch_ws = []
        for ws in batch_ws:
            if not isinstance(ws, torch.Tensor):
                ws = torch.tensor(ws, device=self.device, dtype=torch.float32)
            elif ws.device != self.device or ws.dtype != torch.float32:
                ws = ws.to(device=self.device, dtype=torch.float32)
            assert ws.shape[0] == len(self.decode), [ws.shape[0], len(self.decode)]
            processed_batch_ws.append(ws)
        return torch.stack(processed_batch_ws)

    def weight_sum(self, ws):
        """Computes weight sums given token weights.

        For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
        that are descendants of that node. This is efficiently implemented using sparse matrix multiplication
        with a pre-computed reachability matrix.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (num_nodes,).
        """
        return self.batch_weight_sum(self._preprocess_ws([ws]))[0]

    def batch_weight_sum(self, ws):
        """Batch version of `weight_sum`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)
        masses = torch.sparse.mm(ws[:, self.token_ids], self.M)
        return masses.cpu().numpy()

    def weight_max(self, ws):
        """Computes the max weights given the token weights.

        For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
        that are descendants of that node. This is efficiently implemented using parallel scatter_reduce
        operations on GPU.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (num_nodes,).
        """
        return self.batch_weight_max(self._preprocess_ws([ws]))[0]

    def batch_weight_max(self, ws):
        """Batch version of `weight_max`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)

        # Get leaf weights
        leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
        batch_size = leaf_weights.shape[0]

        # Use scatter_reduce to propagate maximum values in parallel
        result = torch.zeros((batch_size, len(self.children)), device=self.device)
        result.scatter_reduce_(
            dim=1,
            index=self.dst_indices.expand(batch_size, -1),
            src=leaf_weights[:, self.src_indices],
            reduce="amax",
            include_self=False,
        )

        return result.cpu().numpy()
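
A minimal usage sketch for the parallel trie (illustrative vocabulary; device="cpu" is forced here so the example runs without a GPU, and the import path assumes genlm/backend/trie/parallel.py as stated above):

import torch
from genlm.backend.trie.parallel import ParallelTokenCharacterTrie

trie = ParallelTokenCharacterTrie(decode=["a", "ab", "b"], device="cpu")

ws = torch.tensor([0.2, 0.3, 0.5])        # one weight per token
node_ws = trie.weight_sum(ws)             # numpy array over trie nodes

batch = torch.rand(4, len(trie.decode))   # batch_size x vocabulary size
node_sums = trie.batch_weight_sum(batch)  # shape: (4, num_nodes)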

weight_sum(ws)

Computes weight sums given token weights.

For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using sparse matrix multiplication with a pre-computed reachability matrix.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.decode),).

required

Returns:

Type Description
ndarray

Summed weights for each node in the trie, shape (num_nodes,).

Source code in genlm/backend/trie/parallel.py
def weight_sum(self, ws):
    """Computes weight sums given token weights.

    For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
    that are descendants of that node. This is efficiently implemented using sparse matrix multiplication
    with a pre-computed reachability matrix.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (num_nodes,).
    """
    return self.batch_weight_sum(self._preprocess_ws([ws]))[0]

batch_weight_sum(ws)

Batch version of weight_sum.

Parameters:

Name Type Description Default
ws Tensor

Batch of token weights, shape (batch_size × len(self.decode)).

required

Returns:

Type Description
ndarray

Summed weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/backend/trie/parallel.py
def batch_weight_sum(self, ws):
    """Batch version of `weight_sum`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)
    masses = torch.sparse.mm(ws[:, self.token_ids], self.M)
    return masses.cpu().numpy()
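
Conceptually, row i of the reachability matrix marks leaf i together with all of its ancestors, so a single dense-sparse product yields every node's subtree sum at once. A rough consistency check against the sequential trie (illustrative; both implementations should agree up to floating-point precision):

import numpy as np
import torch
from genlm.backend.trie.base import TokenCharacterTrie
from genlm.backend.trie.parallel import ParallelTokenCharacterTrie

vocab = ["a", "ab", "b"]
parallel = ParallelTokenCharacterTrie(decode=vocab, device="cpu")
sequential = TokenCharacterTrie(decode=vocab)

ws = torch.tensor([0.2, 0.3, 0.5])
assert np.allclose(parallel.weight_sum(ws), sequential.weight_sum(ws.numpy()))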

weight_max(ws)

Computes the max weights given the token weights.

For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using parallel scatter_reduce operations on GPU.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.decode),).

required

Returns:

Type Description
ndarray

Maximum weights for each node in the trie, shape (num_nodes,).

Source code in genlm/backend/trie/parallel.py
def weight_max(self, ws):
    """Computes the max weights given the token weights.

    For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
    that are descendants of that node. This is efficiently implemented using parallel scatter_reduce
    operations on GPU.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (num_nodes,).
    """
    return self.batch_weight_max(self._preprocess_ws([ws]))[0]

batch_weight_max(ws)

Batch version of weight_max.

Parameters:

Name Type Description Default
ws Tensor

Batch of token weights, shape (batch_size × len(self.decode)).

required

Returns:

Type Description
ndarray

Maximum weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/backend/trie/parallel.py
def batch_weight_max(self, ws):
    """Batch version of `weight_max`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)

    # Get leaf weights
    leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
    batch_size = leaf_weights.shape[0]

    # Use scatter_reduce to propagate maximum values in parallel
    result = torch.zeros((batch_size, len(self.children)), device=self.device)
    result.scatter_reduce_(
        dim=1,
        index=self.dst_indices.expand(batch_size, -1),
        src=leaf_weights[:, self.src_indices],
        reduce="amax",
        include_self=False,
    )

    return result.cpu().numpy()
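
For readers unfamiliar with scatter_reduce_, a standalone sketch of the "amax" reduction used above (toy indices, unrelated to any particular trie):

import torch

src = torch.tensor([[0.2, 0.3, 0.5, 0.3]])  # values to scatter (e.g. duplicated leaf weights)
index = torch.tensor([[2, 2, 1, 0]])        # destination column for each value
out = torch.zeros(1, 3)

# Each destination column receives the maximum of the values scattered into it.
out.scatter_reduce_(dim=1, index=index, src=src, reduce="amax", include_self=False)
print(out)  # tensor([[0.3000, 0.5000, 0.3000]])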

AsyncTokenCharacterTrie

An asynchronous wrapper for TokenCharacterTrie implementations that provides automatic request batching.

Source code in genlm/backend/trie/async_impl.py
class AsyncTokenCharacterTrie:
    """An asynchronous wrapper for TokenCharacterTrie implementations that provides automatic request batching."""

    def __init__(self, trie):
        """Initialize an `AsyncTokenCharacterTrie`.

        Args:
            trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
        """
        self.trie = trie
        self._queue = None
        self._task = None

    @classmethod
    def from_vocab(cls, vocab, backend="parallel", **kwargs):
        """Creates an `AsyncTokenCharacterTrie` from a vocabulary.

        Args:
            vocab (list): The vocabulary over which the trie will be defined.
            backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                    Defaults to 'parallel' which uses GPU acceleration when available.
            **kwargs: Additional arguments passed to the trie constructor

        Returns:
            (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.
        """
        if backend == "sequential":
            trie = TokenCharacterTrie(decode=vocab, **kwargs)
        elif backend == "parallel":
            trie = ParallelTokenCharacterTrie(decode=vocab, **kwargs)
        else:
            raise ValueError(
                f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
            )
        return cls(trie)

    async def _queue_request(self, request, op):
        if not self._task or self._task.done():
            self.start()

        future = asyncio.Future()
        await self._queue.put((request, future, op))
        return future

    async def weight_sum(self, ws):
        """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The summed weights for each node in the trie.
        """
        future = await self._queue_request(ws, "sum")
        result = await future
        return result

    async def weight_max(self, ws):
        """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The maximum weights for each node in the trie.
        """
        future = await self._queue_request(ws, "max")
        result = await future
        return result

    def start(self):
        """Start the background processing task if not already running."""
        if not self._task or self._task.done():
            self._queue = (
                asyncio.Queue()
            )  # Create a new queue so that it is bound to the current event loop
            self._task = asyncio.create_task(self._background_loop())

    def _do_weight_sums(self, batch_weights):
        return self.trie.batch_weight_sum(batch_weights)

    def _do_weight_maxs(self, batch_weights):
        return self.trie.batch_weight_max(batch_weights)

    async def _background_loop(self):
        """Background task that processes queued weight sum and max requests.

        Continuously monitors the queue for new requests and processes them in batches
        using the underlying trie implementation.

        Raises:
            Exception: If any error occurs during processing, it is propagated to all
                      pending futures in the current batch.
        """
        while True:
            try:
                op_groups = defaultdict(list)

                request, future, op = await self._queue.get()
                op_groups[op].append((request, future))

                while not self._queue.empty():
                    request, future, op = await self._queue.get()
                    op_groups[op].append((request, future))

                for op, group in op_groups.items():
                    requests, futures = zip(*group)

                    if op == "sum":
                        logger.debug(f"processing {len(requests)} sum requests")
                        results = self._do_weight_sums(requests)
                    elif op == "max":
                        logger.debug(f"processing {len(requests)} max requests")
                        results = self._do_weight_maxs(requests)
                    else:
                        raise ValueError(f"Unknown operation: {op}")

                    for future, result in zip(futures, results):
                        future.set_result(result)

            except Exception as e:
                for group in op_groups.values():
                    for _, future in group:
                        if not future.done():
                            future.set_exception(e)
                raise

    async def cleanup(self):
        """Async cleanup - preferred method"""
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    def shutdown(self):
        """Stop the background processing task and cleanup resources."""
        if self._task is not None:
            try:
                self._task.cancel()
            except RuntimeError:
                # Ignore runtime errors that might occur if event loop is closed
                pass
            self._task = None

    def __del__(self):
        self.shutdown()
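
A minimal async usage sketch (illustrative vocabulary and weights; per the docstrings above, concurrent requests are batched together by the background loop):

import asyncio
import torch
from genlm.backend.trie.async_impl import AsyncTokenCharacterTrie

async def main():
    trie = AsyncTokenCharacterTrie.from_vocab(["a", "ab", "b"], backend="parallel", device="cpu")

    ws1 = torch.tensor([0.2, 0.3, 0.5])
    ws2 = torch.tensor([1.0, 0.0, 2.0])

    # Issued concurrently, so the background loop can serve them as one batch.
    sums1, sums2 = await asyncio.gather(trie.weight_sum(ws1), trie.weight_sum(ws2))
    maxs = await trie.weight_max(ws1)

    await trie.cleanup()
    return sums1, sums2, maxs

asyncio.run(main())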

__init__(trie)

Initialize an AsyncTokenCharacterTrie.

Parameters:

Name Type Description Default
trie TokenCharacterTrie | ParallelTokenCharacterTrie

The underlying TokenCharacterTrie or ParallelTokenCharacterTrie instance

required
Source code in genlm/backend/trie/async_impl.py
def __init__(self, trie):
    """Initialize an `AsyncTokenCharacterTrie`.

    Args:
        trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
    """
    self.trie = trie
    self._queue = None
    self._task = None

from_vocab(vocab, backend='parallel', **kwargs) classmethod

Creates an AsyncTokenCharacterTrie from a vocabulary.

Parameters:

Name Type Description Default
vocab list

The vocabulary over which the trie will be defined.

required
backend str

The trie implementation to use - either 'sequential' or 'parallel'. Defaults to 'parallel' which uses GPU acceleration when available.

'parallel'
**kwargs

Additional arguments passed to the trie constructor

{}

Returns:

Type Description
AsyncTokenCharacterTrie

The initialized asynchronous trie instance.

Source code in genlm/backend/trie/async_impl.py
@classmethod
def from_vocab(cls, vocab, backend="parallel", **kwargs):
    """Creates an `AsyncTokenCharacterTrie` from a vocabulary.

    Args:
        vocab (list): The vocabulary over which the trie will be defined.
        backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                Defaults to 'parallel' which uses GPU acceleration when available.
        **kwargs: Additional arguments passed to the trie constructor

    Returns:
        (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.
    """
    if backend == "sequential":
        trie = TokenCharacterTrie(decode=vocab, **kwargs)
    elif backend == "parallel":
        trie = ParallelTokenCharacterTrie(decode=vocab, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
        )
    return cls(trie)
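
The backend argument selects the underlying implementation; for example, to force the sequential trie (sketch):

from genlm.backend.trie.async_impl import AsyncTokenCharacterTrie

# "sequential" wraps TokenCharacterTrie; "parallel" (the default) wraps
# ParallelTokenCharacterTrie, which uses the GPU when one is available.
async_trie = AsyncTokenCharacterTrie.from_vocab(["a", "ab", "b"], backend="sequential")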

weight_sum(ws) async

Queue a weight_sum request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.trie.decode),).

required

Returns:

Type Description
ndarray

The summed weights for each node in the trie.

Source code in genlm/backend/trie/async_impl.py
async def weight_sum(self, ws):
    """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The summed weights for each node in the trie.
    """
    future = await self._queue_request(ws, "sum")
    result = await future
    return result

weight_max(ws) async

Queue a weight_max request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.trie.decode),).

required

Returns:

Type Description
ndarray

The maximum weights for each node in the trie.

Source code in genlm/backend/trie/async_impl.py
async def weight_max(self, ws):
    """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The maximum weights for each node in the trie.
    """
    future = await self._queue_request(ws, "max")
    result = await future
    return result

start()

Start the background processing task if not already running.

Source code in genlm/backend/trie/async_impl.py
def start(self):
    """Start the background processing task if not already running."""
    if not self._task or self._task.done():
        self._queue = (
            asyncio.Queue()
        )  # Create a new queue so that it is bound to the current event loop
        self._task = asyncio.create_task(self._background_loop())

cleanup() async

Cancel the background processing task and await its termination. Preferred over shutdown() when an event loop is running.

Source code in genlm/backend/trie/async_impl.py
async def cleanup(self):
    """Async cleanup - preferred method"""
    if self._task and not self._task.done():
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        self._task = None

shutdown()

Stop the background processing task and cleanup resources.

Source code in genlm/backend/trie/async_impl.py
def shutdown(self):
    """Stop the background processing task and cleanup resources."""
    if self._task is not None:
        try:
            self._task.cancel()
        except RuntimeError:
            # Ignore runtime errors that might occur if event loop is closed
            pass
        self._task = None
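
One possible shutdown pattern (sketch): prefer the awaitable cleanup inside a running event loop; shutdown (also invoked from __del__) is the synchronous fallback that cancels the task without awaiting it.

import asyncio
import torch
from genlm.backend.trie.async_impl import AsyncTokenCharacterTrie

async def main():
    trie = AsyncTokenCharacterTrie.from_vocab(["a", "ab", "b"], backend="sequential")
    await trie.weight_sum(torch.tensor([0.2, 0.3, 0.5]))
    await trie.cleanup()  # cancels the background task and awaits its termination

asyncio.run(main())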