Skip to content

backend

AsyncVirtualLM

Bases: AsyncLM

Async language model using vLLM v1 with global logits processor.

This implementation uses vLLM v1's in-process mode with a global logits processor to efficiently capture full vocabulary log probabilities.

Source code in genlm/backend/llm/vllm.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
class AsyncVirtualLM(AsyncLM):  # pragma: no cover
    """Async language model using vLLM v1 with global logits processor.

    This implementation uses vLLM v1's in-process mode with a global
    logits processor to efficiently capture full vocabulary log probabilities.
    """

    # Sampling parameters for the logprobs paths: a single one-token
    # continuation per prompt, no detokenization, EOS ignored. The actual
    # log probabilities are read from ``self.logprobs_capture`` afterwards.
    default_params = {
        "max_tokens": 1,
        "n": 1,
        "detokenize": False,
        "stop": None,
        "ignore_eos": True,
    }

    def __init__(
        self,
        llm_engine,
        logprobs_capture,
        cache_size=0,
        cache_opts=None,
        batch_size=20,
        timeout=0.02,
    ):
        """Initialize an `AsyncVirtualLM` instance.

        Args:
            llm_engine (LLM): The vLLM engine instance.
            logprobs_capture (GlobalLogprobsCapture): The global logprobs capture processor.
            cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
            cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to None (no extra options).
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching. Defaults to 20.
            timeout (float, optional): Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when ``batch_size`` is reached. Defaults to 0.02.

        Note:
            The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
            ``batch_next_token_logprobs_sync`` bypasses this cache and always re-evaluates; the other three logprobs methods consult it.
        """
        self.llm_engine = llm_engine
        self.logprobs_capture = logprobs_capture
        self.tokenizer = llm_engine.get_tokenizer()
        self.cache = (
            OutputCache(maxsize=cache_size, **(cache_opts or {}))
            if cache_size > 0
            else None
        )
        self.lora_request = None
        self.lora_name_to_ids = {}

        # Auto-batching state for ``next_token_logprobs``.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        # Auto-batching state for ``sample``.
        self.sample_queries = []
        self.sample_timer = None

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, engine_opts=None, **kwargs):
        """Create an `AsyncVirtualLM` instance from a model name.

        Prefix caching is enabled, log stats are disabled and GPU memory
        utilization is set to 0.9 by default; any key in ``engine_opts``
        overrides these defaults.

        Args:
            model_name (str): Name of the model to load.
            engine_opts (dict): Additional options to pass to the `LLM` engine.
            **kwargs: Additional arguments passed to the `AsyncVirtualLM` constructor.

        Returns:
            (AsyncVirtualLM): An `AsyncVirtualLM` instance.

        Raises:
            ImportError: If vLLM is not installed.
        """
        if not HAS_VLLM:
            raise ImportError(  # pragma: no cover
                "vLLM not available. Install vLLM or use AsyncTransformer instead."
            )

        engine_opts = {
            "enable_prefix_caching": True,
            "disable_log_stats": True,
            "gpu_memory_utilization": 0.9,
            **(engine_opts or {}),
        }

        llm = LLM(model=model_name, tokenizer=model_name, **engine_opts)

        # Register the capture processor with the worker's input batch so it
        # sees the logits of every subsequent ``generate()`` call.
        logprobs_capture = GlobalLogprobsCapture()
        model_runner = cls._get_model_runner(llm)
        model_runner.input_batch.logitsprocs.argmax_invariant.append(
            logprobs_capture
        )

        return cls(llm, logprobs_capture, **kwargs)

    @staticmethod
    def _get_model_runner(llm):
        """Walk the vLLM v1 internals to reach the driver worker's model runner.

        This path is brittle against vLLM refactors, so it lives in one
        place and is reused by ``from_name`` (to inject the logits
        processor) and ``underlying_model``.
        """
        engine_core = llm.llm_engine.engine_core.engine_core
        return engine_core.model_executor.driver_worker.worker.model_runner

    @property
    def underlying_model(self):
        """Access the underlying model for advanced use cases."""
        return self._get_model_runner(self.llm_engine).model

    def clear_lora(self):
        """
        Disable any active LoRA adapter for the vLLM engine.
        """
        self.lora_request = None

    def add_new_lora(self, lora_path, lora_name="lora_1"):
        """Load a LoRA adapter into the base model by creating a unique id for it.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Name to assign to the loaded adapter.

        Notes:
            This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
        """
        self.lora_name_to_ids[lora_name] = self.hash_to_int(lora_name)

    def hash_to_int(self, value):
        """Generate a deterministic integer ID for a LoRA adapter from its name.

        Args:
            value (str): The name of the LoRA adapter to hash.

        Returns:
            An integer ID in the range [1, 2^31 - 2], derived from the first
            four bytes of the SHAKE-128 digest of the UTF-8 encoded name.
            Distinct names can collide; the ID is not guaranteed unique.
        """
        hash_bytes = hashlib.shake_128(value.encode("utf-8")).digest(4)
        # Reduce modulo (2**31 - 2) and shift by one: the result is never 0
        # and fits in a signed 32-bit integer.
        return (int.from_bytes(hash_bytes, "big") % (2**31 - 2)) + 1

    def set_lora(self, lora_path, lora_name="lora_1"):
        """Configure a LoRA adapter request for the vLLM engine.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.
                Must already be registered with `add_new_lora()`, which supplies the adapter's ID.

        Raises:
            ValueError: If ``lora_name`` has not been registered with `add_new_lora()`.
        """
        if lora_name not in self.lora_name_to_ids:
            raise ValueError(
                f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
            )
        self.lora_request = LoRARequest(
            lora_name, self.lora_name_to_ids[lora_name], lora_path
        )

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously with auto-batching.

        Concurrent calls to this method are automatically batched into a single
        ``LLM.generate()`` call for efficiency. Use with ``await``.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            result (torch.Tensor): Normalized log probability tensor.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        future = asyncio.get_running_loop().create_future()
        self._add_query(token_ids, future)
        result = await future

        if self.cache is not None:
            self.cache[key] = result

        return result

    def _add_query(self, token_ids, future):
        """Add a query to be evaluated in the next batch.

        The timeout is measured from the *first* queued query, not the most
        recent one: we only arm the timer when the queue transitions from
        empty to non-empty. This prevents starvation when queries trickle in
        faster than ``self.timeout`` but never fill a batch.

        Args:
            token_ids (list[int]): Token IDs representing the query prompt.
            future (asyncio.Future): Future to store the result in.
        """
        self.queries.append((token_ids, future))

        if len(self.queries) >= self.batch_size:
            if self.timer:
                self.timer.cancel()
                self.timer = None
            self._batch_evaluate()
        elif self.timer is None:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, self._batch_evaluate
            )

    def _batch_evaluate(self):
        """Process all queued queries in a single batched ``generate()`` call.

        This runs either directly from ``_add_query`` (batch full) or as a
        ``call_later`` callback. An exception escaping a callback only
        reaches the event loop's exception handler. It would leave the
        awaiting callers hanging, so every failure after the queue is
        drained is delivered to the pending futures instead.
        """
        queries, self.queries = self.queries, []
        if not queries:
            return

        if self.timer:
            self.timer.cancel()
            self.timer = None

        # A caller may have been cancelled while its query was queued. Calling
        # set_result/set_exception on a done future raises InvalidStateError,
        # which would otherwise abort the batch for every other caller.
        queries = [
            (token_ids, future) for token_ids, future in queries if not future.done()
        ]
        if not queries:
            return

        if self.logprobs_capture is None:
            exc = RuntimeError("Cannot use model after cleanup() has been called")
            for _, future in queries:
                future.set_exception(exc)
            return

        # Deduplicate: group futures by identical prompts so each distinct
        # prompt is evaluated only once.
        query_groups = defaultdict(list)
        for token_ids, future in queries:
            query_groups[tuple(token_ids)].append(future)

        unique_token_ids = list(query_groups.keys())

        try:
            self.logprobs_capture.clear()

            prompts = [
                TokensPrompt(prompt_token_ids=list(token_ids))
                for token_ids in unique_token_ids
            ]

            self.llm_engine.generate(
                prompts=prompts,
                sampling_params=SamplingParams(**self.default_params),
                lora_request=self.lora_request,
                use_tqdm=False,
            )

            all_logprobs = self.logprobs_capture.get_all_logprobs()
            assert all_logprobs is not None, "Logprobs should be captured"
            assert all_logprobs.shape[0] == len(unique_token_ids), (
                f"Expected {len(unique_token_ids)} logprobs, got {all_logprobs.shape[0]}"
            )

            for i, key in enumerate(unique_token_ids):
                logprobs = all_logprobs[i]
                futures = query_groups[key]
                if len(futures) == 1:
                    futures[0].set_result(logprobs)
                else:
                    # Duplicate callers each get their own copy, presumably so
                    # one caller cannot mutate another caller's result.
                    for future in futures:
                        future.set_result(logprobs.clone())
        except Exception as exc:
            for futures in query_groups.values():
                for future in futures:
                    if not future.done():
                        future.set_exception(exc)

    def reset_async_queries(self):
        """Clear any pending queries from the queue.

        Use this method when an exception prevented an inference algorithm
        from executing to completion. The futures of the dropped queries
        are cancelled. Otherwise the coroutines awaiting them would never
        be resumed, because the queries will never be evaluated.
        """
        pending = [future for _, future in self.queries]
        pending.extend(future for *_, future in self.sample_queries)

        self.queries = []
        if self.timer:
            self.timer.cancel()
            self.timer = None

        self.sample_queries = []
        if self.sample_timer:
            self.sample_timer.cancel()
            self.sample_timer = None

        for future in pending:
            if not future.done():
                future.cancel()

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Does not support auto-batching. For batched sync calls, use
        ``batch_next_token_logprobs_sync`` instead.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.

        Raises:
            RuntimeError: If called after ``cleanup()`` (on a cache miss).
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        if self.logprobs_capture is None:
            raise RuntimeError("Cannot use model after cleanup() has been called")

        self.logprobs_capture.clear()

        self.llm_engine.generate(
            prompts=TokensPrompt(prompt_token_ids=list(token_ids)),
            sampling_params=SamplingParams(**self.default_params),
            lora_request=self.lora_request,
            use_tqdm=False,
        )

        result = self.logprobs_capture.get_logprobs(batch_index=0)
        assert result is not None, "Logprobs should be captured by global processor"

        if self.cache is not None:
            self.cache[key] = result

        return result

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """
        Request log probabilities of next tokens in a batch synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

        Returns:
            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.

        Raises:
            RuntimeError: If called after ``cleanup()``.

        Note:
            This method does not consult the output cache (unlike the async batch path,
            which delegates to the cached ``next_token_logprobs``). Every prompt is
            re-evaluated.
        """
        if self.logprobs_capture is None:
            raise RuntimeError("Cannot use model after cleanup() has been called")
        # Clear any stale captured logprobs
        self.logprobs_capture.clear()

        # Create prompts for batch
        prompts = [
            TokensPrompt(prompt_token_ids=list(token_ids))
            for token_ids in token_ids_list
        ]

        # Generate one token for each prompt
        self.llm_engine.generate(
            prompts=prompts,
            sampling_params=SamplingParams(**self.default_params),
            lora_request=self.lora_request,
            use_tqdm=False,
        )

        # Get all captured logprobs at once (optimized - single clone)
        all_logprobs = self.logprobs_capture.get_all_logprobs()
        assert all_logprobs is not None, "Logprobs should be captured"
        assert all_logprobs.shape[0] == len(token_ids_list), (
            f"Expected {len(token_ids_list)} logprobs, got {all_logprobs.shape[0]}"
        )

        return all_logprobs

    def clear_cache(self):
        """Clear output cache (no-op when caching is disabled)."""
        if self.cache is not None:
            self.cache.clear()

    def cleanup(self):
        """Explicitly clean up GPU resources. Call this when done with the model."""
        self._cleanup_engine()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()
        return False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()
        return False

    def __del__(self):
        """Clean up resources on deletion."""
        self._cleanup_engine()

    def _cleanup_engine(self):
        """Clean up the vLLM engine and associated resources.

        This is invoked from both :meth:`cleanup` (explicit, during normal
        program flow) and :meth:`__del__` (implicit, possibly at
        interpreter shutdown). The narrow exception classes below cover
        the races and idempotency issues we know about:

        * ``ImportError`` / ``AttributeError`` arise when ``__del__`` runs
          after ``sys.meta_path`` is already torn down during interpreter
          shutdown.
        * ``AssertionError`` is raised by vLLM's
          ``destroy_distributed_environment`` if it's called twice.
        * ``RuntimeError`` can surface from CUDA when the driver is
          already being torn down.

        Anything else is re-raised so real bugs are not swallowed.
        """
        try:
            # ``import gc`` can itself raise ImportError when ``__del__`` is
            # invoked after ``sys.meta_path`` has been torn down at
            # interpreter shutdown, so it lives inside the try block.
            import gc

            # Clear our references
            if hasattr(self, "logprobs_capture"):
                if self.logprobs_capture is not None:
                    self.logprobs_capture.clear()
                self.logprobs_capture = None

            # Delete the engine to free GPU memory
            if hasattr(self, "llm_engine") and self.llm_engine is not None:
                del self.llm_engine
                self.llm_engine = None

            # Force garbage collection
            gc.collect()

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Clean up distributed state
            destroy_model_parallel()
            destroy_distributed_environment()
        except (
            ImportError,
            AttributeError,
            AssertionError,
            RuntimeError,
        ) as e:
            # Best-effort log; during interpreter shutdown logging itself
            # may already be torn down, in which case silently drop.
            with contextlib.suppress(Exception):
                logging.getLogger(__name__).debug(
                    "AsyncVirtualLM cleanup raised %s: %s",
                    type(e).__name__,
                    e,
                )

    def _add_sample_query(self, prompt_token_ids, sampling_params, future):
        """Enqueue a ``sample()`` request; mirrors ``_add_query`` for the logprobs path."""
        self.sample_queries.append((prompt_token_ids, sampling_params, future))
        if len(self.sample_queries) >= self.batch_size:
            if self.sample_timer:
                self.sample_timer.cancel()
                self.sample_timer = None
            self._batch_sample_evaluate()
        elif self.sample_timer is None:
            self.sample_timer = asyncio.get_running_loop().call_later(
                self.timeout, self._batch_sample_evaluate
            )

    def _batch_sample_evaluate(self):
        """Dispatch queued ``sample()`` requests in one batched ``generate()`` call.

        Like ``_batch_evaluate``, this skips futures that are already done
        (for example, cancelled callers). All other futures receive either
        a result or the exception raised while generating.
        """
        queries, self.sample_queries = self.sample_queries, []
        if not queries:
            return
        if self.sample_timer:
            self.sample_timer.cancel()
            self.sample_timer = None
        # Resolving a done (e.g. cancelled) future raises InvalidStateError,
        # which would otherwise abort the whole batch.
        queries = [query for query in queries if not query[2].done()]
        if not queries:
            return
        if self.logprobs_capture is None:
            exc = RuntimeError("Cannot use model after cleanup() has been called")
            for _, _, future in queries:
                future.set_exception(exc)
            return
        try:
            outputs = self.llm_engine.generate(
                prompts=[TokensPrompt(prompt_token_ids=t) for t, _, _ in queries],
                sampling_params=[sp for _, sp, _ in queries],
                lora_request=self.lora_request,
                use_tqdm=False,
            )
            assert len(outputs) == len(queries)
            for output, (_, _, future) in zip(outputs, queries):
                future.set_result(list(output.outputs[0].token_ids))
        except Exception as exc:
            for _, _, future in queries:
                if not future.done():
                    future.set_exception(exc)

    async def sample(
        self,
        prompt_token_ids,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample from the language model.

        Concurrent calls are auto-batched into a single ``LLM.generate()``
        so vLLM continuous-batches the decode steps. Use with ``await``.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
            temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[int]): The sampled token IDs, without a trailing end-of-sequence token.
        """
        future = asyncio.get_running_loop().create_future()
        self._add_sample_query(
            list(prompt_token_ids),
            SamplingParams(
                n=1,
                max_tokens=max_tokens,
                temperature=temperature,
                seed=seed,
                # Stop on the EOS token IDs themselves rather than on their
                # decoded text. Decoding raises for byte tokens that are not
                # valid UTF-8, and text matching misses special tokens that
                # are skipped during detokenization.
                stop_token_ids=list(eos_token_ids),
            ),
            future,
        )
        token_ids = await future
        if token_ids and token_ids[-1] in eos_token_ids:
            token_ids = token_ids[:-1]
        return token_ids

__init__(llm_engine, logprobs_capture, cache_size=0, cache_opts=None, batch_size=20, timeout=0.02)

Initialize an AsyncVirtualLM instance.

Parameters:

Name Type Description Default
llm_engine LLM

The vLLM engine instance.

required
logprobs_capture GlobalLogprobsCapture

The global logprobs capture processor.

required
cache_size int

Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.

0
cache_opts dict

Additional options to pass to the OutputCache constructor. Defaults to None (no extra options).

None
batch_size int

Maximum queries to process in one batch during auto-batching. Defaults to 20.

20
timeout float

Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when batch_size is reached. Defaults to 0.02.

0.02
Note

The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine. batch_next_token_logprobs_sync bypasses this cache and always re-evaluates; the other three logprobs methods consult it.

Source code in genlm/backend/llm/vllm.py
def __init__(
    self,
    llm_engine,
    logprobs_capture,
    cache_size=0,
    cache_opts=None,
    batch_size=20,
    timeout=0.02,
):
    """Set up an `AsyncVirtualLM` around an existing vLLM engine.

    Args:
        llm_engine (LLM): The vLLM engine; its tokenizer is fetched via ``get_tokenizer()``.
        logprobs_capture (GlobalLogprobsCapture): Processor from which next-token log probabilities are read.
        cache_size (int, optional): Capacity of the output cache; a value of 0 or less disables caching. Defaults to 0.
        cache_opts (dict, optional): Extra keyword arguments for the `OutputCache` constructor. Defaults to None.
        batch_size (int, optional): Queue length at which an auto-batch is dispatched immediately. Defaults to 20.
        timeout (float, optional): Delay in seconds, counted from the first queued query, before a partial batch is dispatched. Defaults to 0.02.

    Note:
        Cached entries are keyed by the prompt token sequence. KV caching is
        left to the vLLM engine. ``batch_next_token_logprobs_sync`` never reads
        the output cache.
    """
    self.llm_engine = llm_engine
    self.logprobs_capture = logprobs_capture
    self.tokenizer = llm_engine.get_tokenizer()

    if cache_size > 0:
        extra_cache_opts = cache_opts if cache_opts else {}
        self.cache = OutputCache(maxsize=cache_size, **extra_cache_opts)
    else:
        self.cache = None

    # LoRA configuration: no adapter active, none registered yet.
    self.lora_request = None
    self.lora_name_to_ids = {}

    # Auto-batching state for next-token logprob queries.
    self.batch_size = batch_size
    self.timeout = timeout
    self.queries = []
    self.timer = None

    # Auto-batching state for sampling queries.
    self.sample_queries = []
    self.sample_timer = None

    super().__init__(tokenizer=self.tokenizer)

from_name(model_name, engine_opts=None, **kwargs) classmethod

Create an AsyncVirtualLM instance from a model name.

Parameters:

Name Type Description Default
model_name str

Name of the model to load.

required
engine_opts dict

Additional options to pass to the LLM engine.

None
**kwargs

Additional arguments passed to the AsyncVirtualLM constructor.

{}

Returns:

Type Description
AsyncVirtualLM

An AsyncVirtualLM instance.

Source code in genlm/backend/llm/vllm.py
@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
    """Build an `AsyncVirtualLM` by loading ``model_name`` into a fresh vLLM engine.

    By default the engine has prefix caching enabled, stats logging disabled,
    and GPU memory utilization set to 0.9. Entries in ``engine_opts``
    take precedence over these defaults.

    Args:
        model_name (str): Model (and tokenizer) name to load.
        engine_opts (dict): Extra keyword arguments for the `LLM` engine.
        **kwargs: Forwarded to the `AsyncVirtualLM` constructor.

    Returns:
        (AsyncVirtualLM): The newly constructed instance.

    Raises:
        ImportError: When vLLM is unavailable.
    """
    if not HAS_VLLM:
        raise ImportError(  # pragma: no cover
            "vLLM not available. Install vLLM or use AsyncTransformer instead."
        )

    merged_opts = dict(
        enable_prefix_caching=True,
        disable_log_stats=True,
        gpu_memory_utilization=0.9,
    )
    merged_opts.update(engine_opts or {})

    engine = LLM(model=model_name, tokenizer=model_name, **merged_opts)

    # Hook the capture processor into the model runner's logits processors.
    capture = GlobalLogprobsCapture()
    runner = cls._get_model_runner(engine)
    runner.input_batch.logitsprocs.argmax_invariant.append(capture)

    return cls(engine, capture, **kwargs)

underlying_model property

Access the underlying model for advanced use cases.

clear_lora()

Disable any active LoRA adapter for the vLLM engine.

Source code in genlm/backend/llm/vllm.py
def clear_lora(self):
    """Stop using a LoRA adapter for subsequent vLLM requests.

    Only the active request is dropped. Names registered with
    ``add_new_lora`` keep their IDs and can be re-activated with ``set_lora``.
    """
    self.lora_request = None

add_new_lora(lora_path, lora_name='lora_1')

Load a LoRA adapter into the base model by creating a unique id for it.

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Name to assign to the loaded adapter.

'lora_1'
Notes

This does not activate the adapter immediately. Call set_lora() to enable the adapter.

Source code in genlm/backend/llm/vllm.py
def add_new_lora(self, lora_path, lora_name="lora_1"):
    """Register a LoRA adapter name and give it a deterministic integer ID.

    Args:
        lora_path (str): Path to the adapter weights directory or HuggingFace hub identifier (not used during registration).
        lora_name (str): Name under which the adapter is registered.

    Notes:
        Registration does not activate the adapter; call `set_lora()` for that.
    """
    adapter_id = self.hash_to_int(lora_name)
    self.lora_name_to_ids[lora_name] = adapter_id

hash_to_int(value)

Generates a deterministic integer ID for a LoRA adapter from its name. IDs are hash-derived, so distinct names can collide.

Parameters:

Name Type Description Default
value str

The name of the LoRA adapter to hash.

required

Returns:

Type Description

An integer ID corresponding to the LoRA adapter, in the range [1, 2^31 - 2].

Source code in genlm/backend/llm/vllm.py
def hash_to_int(self, value):
    """Generate a deterministic integer ID for a LoRA adapter from its name.

    Args:
        value (str): The name of the LoRA adapter to hash.

    Returns:
        An integer ID in the range [1, 2^31 - 2], derived from the first
        four bytes of the SHAKE-128 digest of the UTF-8 encoded name.
        Distinct names can collide; the ID is not guaranteed unique.
    """
    hash_bytes = hashlib.shake_128(value.encode("utf-8")).digest(4)
    # Reduce modulo (2**31 - 2) and shift by one: the result is never 0 and
    # fits in a signed 32-bit integer.
    return (int.from_bytes(hash_bytes, "big") % (2**31 - 2)) + 1

set_lora(lora_path, lora_name='lora_1')

Configure a LoRA adapter request for the vLLM engine.

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Identifier name to associate with this LoRA adapter within vLLM.

'lora_1'
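Note: set_lora takes no lora_id argument. The adapter's integer ID is the one assigned when lora_name was registered with add_new_lora(); a ValueError is raised if the name was never registered.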
Source code in genlm/backend/llm/vllm.py
def set_lora(self, lora_path, lora_name="lora_1"):
    """Configure a LoRA adapter request for the vLLM engine.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.
            Must already be registered with `add_new_lora()`, which supplies the adapter's ID.

    Raises:
        ValueError: If ``lora_name`` has not been registered with `add_new_lora()`.
    """
    if lora_name not in self.lora_name_to_ids:
        raise ValueError(
            f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
        )
    self.lora_request = LoRARequest(
        lora_name, self.lora_name_to_ids[lora_name], lora_path
    )

next_token_logprobs(token_ids) async

Request log probabilities of next token asynchronously with auto-batching.

Concurrent calls to this method are automatically batched into a single LLM.generate() call for efficiency. Use with await.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Name Type Description
result Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/vllm.py
async def next_token_logprobs(self, token_ids):
    """Asynchronously fetch next-token log probabilities, auto-batched.

    Calls made concurrently are queued and served by one batched
    ``LLM.generate()`` call. Cached results are returned without queueing.
    Must be awaited.

    Args:
        token_ids (list[int]): Prompt token IDs.

    Returns:
        result (torch.Tensor): Normalized log probability tensor.
    """
    cache_key = tuple(token_ids)

    if self.cache is not None and cache_key in self.cache:
        return self.cache[cache_key]

    loop = asyncio.get_running_loop()
    pending = loop.create_future()
    self._add_query(token_ids, pending)
    logprobs = await pending

    # Re-check the cache attribute after the await, as the caller-visible
    # state may have changed while this coroutine was suspended.
    if self.cache is not None:
        self.cache[cache_key] = logprobs

    return logprobs

reset_async_queries()

Clear any pending queries from the queue.

Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/vllm.py
def reset_async_queries(self):
    """Clear any pending queries from the queue.

    Use this method when an exception prevented an inference algorithm
    from executing to completion. The futures of the dropped queries
    are cancelled. Otherwise the coroutines awaiting them would never be
    resumed, because the queries will never be evaluated.
    """
    pending = [future for _, future in self.queries]
    pending.extend(future for *_, future in self.sample_queries)

    self.queries = []
    if self.timer:
        self.timer.cancel()
        self.timer = None

    self.sample_queries = []
    if self.sample_timer:
        self.sample_timer.cancel()
        self.sample_timer = None

    for future in pending:
        if not future.done():
            future.cancel()

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Does not support auto-batching. For batched sync calls, use batch_next_token_logprobs_sync instead.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/vllm.py
def next_token_logprobs_sync(self, token_ids):
    """Compute next-token log probabilities for a single prompt, synchronously.

    No auto-batching happens here; to score several prompts in one call use
    ``batch_next_token_logprobs_sync``.

    Args:
        token_ids (list[int]): Token IDs of the prompt to score.

    Returns:
        (torch.Tensor): Normalized log probability tensor.

    Raises:
        RuntimeError: If the prompt is not cached and ``cleanup()`` has already
            been called.
    """
    cache_key = tuple(token_ids)

    if self.cache is not None and cache_key in self.cache:
        return self.cache[cache_key]

    if self.logprobs_capture is None:
        raise RuntimeError("Cannot use model after cleanup() has been called")

    # Discard anything captured by an earlier generate() call.
    self.logprobs_capture.clear()

    single_prompt = TokensPrompt(prompt_token_ids=list(token_ids))
    sampling = SamplingParams(**self.default_params)
    self.llm_engine.generate(
        prompts=single_prompt,
        sampling_params=sampling,
        lora_request=self.lora_request,
        use_tqdm=False,
    )

    # The global logits processor recorded the distribution for row 0.
    logprobs = self.logprobs_capture.get_logprobs(batch_index=0)
    assert logprobs is not None, "Logprobs should be captured by global processor"

    if self.cache is not None:
        self.cache[cache_key] = logprobs
    return logprobs

batch_next_token_logprobs_sync(token_ids_list)

Request log probabilities of next tokens in a batch synchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists, each representing a prompt to the language model.

required

Returns:

Type Description
Tensor

A tensor of normalized log probabilities with one row per prompt in the input list.

Note

This method does not consult the output cache (unlike the async batch path, which delegates to the cached next_token_logprobs). Every prompt is re-evaluated.

Source code in genlm/backend/llm/vllm.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """
    Compute next-token log probabilities for several prompts in one synchronous call.

    Args:
        token_ids_list (list[list[int]]): One list of token IDs per prompt.

    Returns:
        (torch.Tensor): Normalized log probabilities, with one entry along the
            first dimension for each prompt in ``token_ids_list``.

    Raises:
        RuntimeError: If ``cleanup()`` has already been called.

    Note:
        This method does not consult the output cache (unlike the async batch path,
        which delegates to the cached ``next_token_logprobs``). Every prompt is
        re-evaluated.
    """
    if self.logprobs_capture is None:
        raise RuntimeError("Cannot use model after cleanup() has been called")

    # Start from an empty capture buffer so only this batch is read back.
    self.logprobs_capture.clear()

    batch_prompts = [TokensPrompt(prompt_token_ids=list(ids)) for ids in token_ids_list]
    sampling = SamplingParams(**self.default_params)

    # One generation step per prompt; the capture object records the logprobs.
    self.llm_engine.generate(
        prompts=batch_prompts,
        sampling_params=sampling,
        lora_request=self.lora_request,
        use_tqdm=False,
    )

    # Read every captured row back in a single call.
    stacked = self.logprobs_capture.get_all_logprobs()
    assert stacked is not None, "Logprobs should be captured"
    assert stacked.shape[0] == len(token_ids_list), (
        f"Expected {len(token_ids_list)} logprobs, got {stacked.shape[0]}"
    )
    return stacked

clear_cache()

Clear output cache.

Source code in genlm/backend/llm/vllm.py
def clear_cache(self):
    """Clear output cache.

    ``self.cache`` is ``None`` when output caching is disabled. Otherwise it
    is emptied in place with ``clear()``, so existing references to the same
    cache object observe the cleared state.
    """
    # Compare against None, as the lookup paths do, rather than relying on
    # truthiness: a cache type with a custom __bool__/__len__ could otherwise
    # be skipped while still holding entries.
    if self.cache is not None:
        self.cache.clear()

cleanup()

Explicitly clean up GPU resources. Call this when done with the model.

Source code in genlm/backend/llm/vllm.py
def cleanup(self):
    """Release the engine's GPU resources explicitly.

    Call this once you are finished with the model; it delegates to
    ``_cleanup_engine``.
    """
    release_engine = self._cleanup_engine
    release_engine()

__del__()

Clean up resources on deletion.

Source code in genlm/backend/llm/vllm.py
def __del__(self):
    """Release engine resources when the object is garbage-collected.

    Delegates to ``_cleanup_engine``, the same routine ``cleanup()`` uses.
    """
    release_engine = self._cleanup_engine
    release_engine()

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Concurrent calls are auto-batched into a single LLM.generate() so vLLM continuous-batches the decode steps. Use with await.

Parameters:

Name Type Description Default
prompt_token_ids list[int]

The token IDs of the prompt.

required
eos_token_ids list[int]

The token IDs of the end-of-sequence tokens.

required
temperature float

The temperature to use to rescale the logits. Defaults to 1.0.

1.0
max_tokens int

The maximum number of tokens to generate.

required
seed int

The seed for the random number generator. Defaults to None.

None

Returns:

Type Description
list[int]

The sampled token IDs.

Source code in genlm/backend/llm/vllm.py
async def sample(
    self,
    prompt_token_ids,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Sample from the language model.

    Concurrent calls are auto-batched into a single ``LLM.generate()``
    so vLLM continuous-batches the decode steps. Use with ``await``.

    Generation is asked to stop when any token in ``eos_token_ids`` is
    produced, or after ``max_tokens`` tokens. If the returned sequence ends
    with one of ``eos_token_ids``, that final token is removed.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
        max_tokens (int): The maximum number of tokens to generate.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs, without a trailing EOS token.
    """
    eos_ids = list(eos_token_ids)
    future = asyncio.get_running_loop().create_future()
    self._add_sample_query(
        list(prompt_token_ids),
        SamplingParams(
            n=1,
            max_tokens=max_tokens,
            temperature=temperature,
            seed=seed,
            # Stop on the EOS token IDs themselves rather than on their decoded
            # text. Text matching can fire inside an unrelated token that merely
            # contains the same characters, raises UnicodeDecodeError for token
            # bytes that are not valid UTF-8, and may never fire for special
            # tokens that are hidden during detokenization.
            stop_token_ids=eos_ids,
        ),
        future,
    )
    token_ids = await future
    if token_ids and token_ids[-1] in eos_ids:
        token_ids = token_ids[:-1]
    return token_ids

AsyncTransformer

Bases: AsyncLM

Asynchronous wrapper around a HuggingFace causal language model with caching support.

This class provides an asynchronous interface to HuggingFace language models with automatic batching and caching (output and KV) for improved efficiency.

Source code in genlm/backend/llm/hf.py
class AsyncTransformer(AsyncLM):
    """Asynchronous wrapper around a HuggingFace causal language model with caching support.

    This class provides an asynchronous interface to HuggingFace language models with automatic batching
    and caching (output and KV) for improved efficiency.
    """

    @classmethod
    def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
        """Create an AsyncTransformer instance from a pretrained HuggingFace model.

        Args:
            model_id (str): Model identifier in HuggingFace's model hub.
            bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
                Defaults to None.
            hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
                Defaults to None.
            **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

        Returns:
            (AsyncTransformer): An initialized `AsyncTransformer` instance.
        """
        if bitsandbytes_opts:
            bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
        else:
            bnb_config = None

        _hf_opts = {
            "device_map": "auto",
            "torch_dtype": "auto",
        }
        if hf_opts:
            _hf_opts.update(hf_opts)

        tok = AutoTokenizer.from_pretrained(model_id)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, **_hf_opts
        )

        return cls(mod, tok, **kwargs)

    @torch.no_grad()
    def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
        """Initialize an AsyncTransformer instance.

        Args:
            hf_model: A HuggingFace CausalLM model instance.
            hf_tokenizer: A HuggingFace Tokenizer.
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
                Defaults to 20.
            timeout (float, optional): Seconds to wait since last query before processing current batch.
                Defaults to 0.02.
        """
        self.model = hf_model
        self.tokenizer = hf_tokenizer
        self.device = hf_model.device
        self.cache = TokenTrie()

        # Queries to be batched. Each query is a sequence of tokens,
        # and a Future to be called when the query is resolved.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        self.model.eval()

        super().__init__(tokenizer=self.tokenizer)

    def clear_cache(self):
        """Clear the cache of log probabilities and key/value pairs."""
        self.cache = TokenTrie()

    def clear_kv_cache(self):
        """Clear any key and value vectors from the cache."""
        self.cache.clear_kv_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue.

        Use this method when an exception prevented an inference algorithm from
        executing to completion. Any scheduled batch timer is cancelled as well.
        """
        self.queries = []
        if self.timer:
            self.timer.cancel()
            self.timer = None

    @torch.no_grad()
    def cache_kv(self, prompt_tokens):
        """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

        Args:
            prompt_tokens (list[int]): token ids for the prompt to cache.
        """
        result = self.model(torch.tensor([prompt_tokens]).to(self.device))
        node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
        node.past_key_values = result.past_key_values

    def add_new_lora(self, lora_path, lora_name="lora_1"):
        """Load a LoRA adapter into the base model.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Name to assign to the loaded adapter.

        Notes:
            This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
        """
        self.model.load_adapter(lora_path, lora_name)

    def set_lora(self, lora_path=None, lora_name="lora_1"):
        """Activate a previously loaded LoRA adapter.

        Both caches are cleared first, since cached results came from the
        previously active weights.

        Args:
            lora_path: Unused; accepted for backward compatibility.
            lora_name (str): Name of the LoRA adapter to activate.

        Raises:
            ValueError: If no adapter named `lora_name` has been loaded.
        """
        if lora_name not in list(self.model.peft_config.keys()):
            raise ValueError(
                f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
            )

        self.clear_kv_cache()
        self.clear_cache()
        self.model.set_adapter(lora_name)

    def clear_lora(self):
        """
        Deactivate all LoRA adapters.
        """
        self.clear_kv_cache()
        self.clear_cache()
        self.model.set_adapter([])

    @torch.no_grad()
    def batch_evaluate_queries(self):
        """
        Process a batch of queued language model queries.

        This method is called internally when the `batch_size` has been met or the `timeout` has expired.

        Queries with the same prompt *and* the same cached past are evaluated once
        and share the result. If evaluation raises, the exception is set on every
        waiting future, so callers awaiting them see the error instead of hanging.
        """

        queries, self.queries = self.queries, []
        if len(queries) == 0:
            return

        # Group duplicate queries. The prompt alone is not a safe key: two queries
        # that resume from different cached prefixes (different `past` objects) can
        # have identical suffix tokens yet need different results.
        query_groups = defaultdict(list)
        for query in queries:
            key = (id(query.past), query.past_len, tuple(query.prompt))
            query_groups[key].append(query)
        groups = list(query_groups.values())

        # Use one representative query from each group
        unique_queries = [group[0] for group in groups]

        # Everything from tensor construction to the forward pass can fail
        # (e.g. device errors); any failure is routed to the waiting futures.
        try:
            past_example = next((q.past for q in unique_queries if q.past), False)
            max_past_length = max(q.past_len for q in unique_queries)
            max_query_length = max(len(q.prompt) for q in unique_queries)

            padding_token_id = (
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else 0
            )

            input_ids = torch.tensor(
                [
                    q.prompt_padded(padding_token_id, max_query_length)
                    for q in unique_queries
                ]
            ).to(self.device)
            attn_masks = torch.tensor(
                [
                    q.attention_mask(max_past_length, max_query_length)
                    for q in unique_queries
                ]
            ).to(self.device)
            posn_ids = torch.tensor(
                [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
            ).to(self.device)

            if past_example:
                # Per layer, concatenate every query's padded entry j of the legacy
                # (key, value) pair along the batch dimension.
                pasts = DynamicCache.from_legacy_cache(
                    [
                        [
                            torch.cat(
                                tuple(
                                    q.past_padded(
                                        layer,
                                        j,
                                        max_past_length,
                                        past_example[0][0].dtype,
                                        self.device,
                                        past_example[0][0].shape,
                                    )
                                    for q in unique_queries
                                ),
                                dim=0,
                            )
                            for j in range(2)
                        ]
                        for layer in range(len(past_example))
                    ]
                )
            else:
                # No cached past: run without a KV cache. Wrapping None in an empty
                # DynamicCache would make `use_cache` below unconditionally True.
                pasts = None

            results = self.model(
                input_ids,
                attention_mask=attn_masks,
                position_ids=posn_ids,
                past_key_values=pasts,
                use_cache=pasts is not None,
            )

            assert len(results.logits) == len(unique_queries)
        except Exception as exc:
            for query in queries:
                if not query.future.done():
                    query.future.set_exception(exc)
            return

        for group, result in zip(groups, results.logits):
            for query in group:
                # A waiter may have been cancelled; skip it instead of raising
                # InvalidStateError and leaving the remaining futures unresolved.
                if not query.future.done():
                    query.future.set_result(result)

    @torch.no_grad()
    def add_query(self, query, future, past):
        """Add a query to be evaluated in the next batch.

        This method is called internally when a `next_token_logprobs` request is made.

        Args:
            query (list[int]): Token IDs representing the query prompt
            future (asyncio.Future): Future to store the result in
            past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
                or None if this is a new query
        """
        self.queries.append(Query(query, future, past))

        # Each new query restarts the idle timer; a full batch is flushed at once.
        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, self.batch_evaluate_queries
            )

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple:
                - CacheNode: The deepest node in the cache tree that matches the token sequence
                - int: Number of tokens matched from the start of token_ids
                - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                    or None if no cached states were found
                - int: Base index indicating where the past states start in token_ids
        """
        # Walk while tokens can be found
        node = self.cache
        next_token_index = 0

        past = None
        base = 0
        while next_token_index < len(token_ids):
            if node.past_key_values is not None:
                past = node.past_key_values
                base = next_token_index
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        return node, next_token_index, past, base

    @torch.no_grad()
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If `token_ids` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, base = self.walk_cache(token_ids)

        # If we processed all tokens, then we're done.
        if next_token_index == len(token_ids):
            return node.logprobs

        # Create a future with the prompt
        future = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], future, past)
        logits = await future

        # Create new nodes
        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    @torch.no_grad()
    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If `token_ids` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        # Walk while tokens can be found
        node, next_token_index, past, base = self.walk_cache(token_ids)

        if next_token_index == len(token_ids):
            return node.logprobs

        # Resume from `past`, the KV state walk_cache found for token_ids[:base]
        # (the same state the async path uses). `node` itself may carry no KV
        # even when an ancestor does. The state is copied into a fresh cache
        # because a cache object passed to the model may be extended in place,
        # which would corrupt the KV tensors stored in the trie.
        if past is not None:
            past_key_values = DynamicCache.from_legacy_cache(
                [[past[layer][0], past[layer][1]] for layer in range(len(past))]
            )
        else:
            past_key_values = None

        logits = self.model(
            torch.tensor([token_ids[base:]]).to(self.device),
            past_key_values=past_key_values,
            use_cache=past_key_values is not None,
        ).logits[0]

        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    def next_token_logprobs_uncached(self, token_ids):
        """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If `token_ids` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        with torch.no_grad():
            logits = self.model(
                torch.tensor([token_ids]).to(self.device),
                past_key_values=None,
                use_cache=False,
            ).logits[0]
            # Only the last position predicts the token after the prompt.
            return torch.log_softmax(logits[-1], dim=0)

from_name(model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs) classmethod

Create an AsyncTransformer instance from a pretrained HuggingFace model.

Parameters:

Name Type Description Default
model_id str

Model identifier in HuggingFace's model hub.

required
bitsandbytes_opts dict

Additional configuration options for bitsandbytes quantization. Defaults to None.

None
hf_opts dict

Additional configuration options for loading the HuggingFace model. Defaults to None.

None
**kwargs

Additional arguments passed to the AsyncTransformer constructor

{}

Returns:

Type Description
AsyncTransformer

An initialized AsyncTransformer instance.

Source code in genlm/backend/llm/hf.py
@classmethod
def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
    """Build an AsyncTransformer from a pretrained HuggingFace checkpoint.

    Args:
        model_id (str): Model identifier in HuggingFace's model hub.
        bitsandbytes_opts (dict, optional): Keyword arguments for
            ``BitsAndBytesConfig``; when falsy, no quantization config is used.
            Defaults to None.
        hf_opts (dict, optional): Extra keyword arguments for
            ``AutoModelForCausalLM.from_pretrained``; they override the defaults
            ``device_map="auto"`` and ``torch_dtype="auto"``. Defaults to None.
        **kwargs: Forwarded to the ``AsyncTransformer`` constructor.

    Returns:
        (AsyncTransformer): An initialized ``AsyncTransformer`` instance.
    """
    quantization = BitsAndBytesConfig(**bitsandbytes_opts) if bitsandbytes_opts else None

    load_kwargs = dict(device_map="auto", torch_dtype="auto")
    if hf_opts:
        load_kwargs.update(hf_opts)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=quantization, **load_kwargs
    )
    return cls(model, tokenizer, **kwargs)

__init__(hf_model, hf_tokenizer, batch_size=20, timeout=0.02)

Initialize an AsyncTransformer instance.

Parameters:

Name Type Description Default
hf_model

A HuggingFace CausalLM model instance.

required
hf_tokenizer

A HuggingFace Tokenizer.

required
batch_size int

Maximum queries to process in one batch during auto-batching. Defaults to 20.

20
timeout float

Seconds to wait since last query before processing current batch. Defaults to 0.02.

0.02
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
    """Wrap a HuggingFace causal LM and tokenizer for asynchronous use.

    Args:
        hf_model: A HuggingFace CausalLM model instance.
        hf_tokenizer: A HuggingFace Tokenizer.
        batch_size (int, optional): Queue length at which pending queries are
            evaluated immediately. Defaults to 20.
        timeout (float, optional): Seconds since the most recent query before
            the pending batch is evaluated. Defaults to 0.02.
    """
    self.model, self.tokenizer = hf_model, hf_tokenizer
    self.device = hf_model.device
    # Token trie holding cached log probabilities (and, where stored, KV states).
    self.cache = TokenTrie()

    # Auto-batching state: queued queries (token ids + Future), the size that
    # triggers an immediate flush, the idle timeout, and the scheduled flush.
    self.queries = []
    self.batch_size, self.timeout = batch_size, timeout
    self.timer = None

    # Inference only: put the model in evaluation mode.
    self.model.eval()

    super().__init__(tokenizer=self.tokenizer)

clear_cache()

Clear the cache of log probabilities and key/value pairs.

Source code in genlm/backend/llm/hf.py
def clear_cache(self):
    """Discard all cached log probabilities and key/value pairs.

    The existing trie is replaced by a brand-new empty ``TokenTrie``.
    """
    empty_trie = TokenTrie()
    self.cache = empty_trie

clear_kv_cache()

Clear any key and value vectors from the cache.

Source code in genlm/backend/llm/hf.py
def clear_kv_cache(self):
    """Ask the cache trie to drop its stored key and value vectors."""
    trie = self.cache
    trie.clear_kv_cache()

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/hf.py
def reset_async_queries(self):
    """Clear any pending language model queries from the queue.

    Use this method when an exception prevented an inference algorithm from
    executing to completion. Any scheduled batch timer is cancelled as well,
    so no flush callback remains pending after the reset.
    """
    self.queries = []
    if self.timer:
        self.timer.cancel()
        self.timer = None

cache_kv(prompt_tokens)

Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

Parameters:

Name Type Description Default
prompt_tokens list[int]

token ids for the prompt to cache.

required
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def cache_kv(self, prompt_tokens):
    """Precompute and store the key/value vectors for a prompt.

    Later queries that start with this prompt only need to run the model on
    the tokens that follow it.

    Args:
        prompt_tokens (list[int]): token ids for the prompt to cache.
    """
    batch = torch.tensor([prompt_tokens]).to(self.device)
    output = self.model(batch)
    # Record logits for the prompt in the trie, then attach the KV state to
    # the node reached at the end of the prompt.
    prompt_node = self.cache.extend_cache(0, prompt_tokens, output.logits[0], 0)
    prompt_node.past_key_values = output.past_key_values

add_new_lora(lora_path, lora_name='lora_1')

Load a LoRA adapter into the base model.

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Name to assign to the loaded adapter.

'lora_1'
Notes

This does not activate the adapter immediately. Call set_lora() to enable the adapter.

Source code in genlm/backend/llm/hf.py
def add_new_lora(self, lora_path, lora_name="lora_1"):
    """Register a LoRA adapter with the base model without activating it.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Name under which the adapter is registered.

    Notes:
        The adapter stays inactive until `set_lora()` is called with its name.
    """
    load = self.model.load_adapter
    load(lora_path, lora_name)

set_lora(lora_path=None, lora_name='lora_1')

Activate a previously loaded LoRA adapter.

Parameters:

Name Type Description Default
lora_name str

Name of the LoRA adapter to activate.

'lora_1'
Source code in genlm/backend/llm/hf.py
def set_lora(self, lora_path=None, lora_name="lora_1"):
    """Switch the model to a LoRA adapter that was loaded earlier.

    Args:
        lora_path: Unused; accepted for signature compatibility.
        lora_name (str): Name of the LoRA adapter to activate.

    Raises:
        ValueError: If no adapter named ``lora_name`` has been loaded.
    """
    registered = self.model.peft_config.keys()
    if lora_name not in registered:
        raise ValueError(
            f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
        )

    # Cached results were produced by the previous weights; drop them first.
    for reset in (self.clear_kv_cache, self.clear_cache):
        reset()
    self.model.set_adapter(lora_name)

clear_lora()

Deactivate all LoRA adapters.

Source code in genlm/backend/llm/hf.py
def clear_lora(self):
    """
    Turn off every LoRA adapter, first discarding results cached under it.
    """
    for reset in (self.clear_kv_cache, self.clear_cache):
        reset()
    no_adapters = []
    self.model.set_adapter(no_adapters)

batch_evaluate_queries()

Process a batch of queued language model queries.

This method is called internally when the batch_size has been met or the timeout has expired.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def batch_evaluate_queries(self):
    """
    Process a batch of queued language model queries.

    This method is called internally when the `batch_size` has been met or the `timeout` has expired.

    Queries with the same prompt *and* the same cached past are evaluated once
    and share the result. If evaluation raises, the exception is set on every
    waiting future, so callers awaiting them see the error instead of hanging.
    """

    queries, self.queries = self.queries, []
    if len(queries) == 0:
        return

    # Group duplicate queries. The prompt alone is not a safe key: two queries
    # that resume from different cached prefixes (different `past` objects) can
    # have identical suffix tokens yet need different results.
    query_groups = defaultdict(list)
    for query in queries:
        key = (id(query.past), query.past_len, tuple(query.prompt))
        query_groups[key].append(query)
    groups = list(query_groups.values())

    # Use one representative query from each group
    unique_queries = [group[0] for group in groups]

    # Everything from tensor construction to the forward pass can fail
    # (e.g. device errors); any failure is routed to the waiting futures.
    try:
        past_example = next((q.past for q in unique_queries if q.past), False)
        max_past_length = max(q.past_len for q in unique_queries)
        max_query_length = max(len(q.prompt) for q in unique_queries)

        padding_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else 0
        )

        input_ids = torch.tensor(
            [
                q.prompt_padded(padding_token_id, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        attn_masks = torch.tensor(
            [
                q.attention_mask(max_past_length, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        posn_ids = torch.tensor(
            [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
        ).to(self.device)

        if past_example:
            # Per layer, concatenate every query's padded entry j of the legacy
            # (key, value) pair along the batch dimension.
            pasts = DynamicCache.from_legacy_cache(
                [
                    [
                        torch.cat(
                            tuple(
                                q.past_padded(
                                    layer,
                                    j,
                                    max_past_length,
                                    past_example[0][0].dtype,
                                    self.device,
                                    past_example[0][0].shape,
                                )
                                for q in unique_queries
                            ),
                            dim=0,
                        )
                        for j in range(2)
                    ]
                    for layer in range(len(past_example))
                ]
            )
        else:
            # No cached past: run without a KV cache. Wrapping None in an empty
            # DynamicCache would make `use_cache` below unconditionally True.
            pasts = None

        results = self.model(
            input_ids,
            attention_mask=attn_masks,
            position_ids=posn_ids,
            past_key_values=pasts,
            use_cache=pasts is not None,
        )

        assert len(results.logits) == len(unique_queries)
    except Exception as exc:
        for query in queries:
            if not query.future.done():
                query.future.set_exception(exc)
        return

    for group, result in zip(groups, results.logits):
        for query in group:
            # A waiter may have been cancelled; skip it instead of raising
            # InvalidStateError and leaving the remaining futures unresolved.
            if not query.future.done():
                query.future.set_result(result)

add_query(query, future, past)

Add a query to be evaluated in the next batch.

This method is called internally when a next_token_logprobs request is made.

Parameters:

Name Type Description Default
query list[int]

Token IDs representing the query prompt

required
future Future

Future to store the result in

required
past list[tuple[Tensor]] | None

Past key/value states from previous evaluation, or None if this is a new query

required
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def add_query(self, query, future, past):
    """Queue a query for the next batched evaluation.

    Called internally whenever `next_token_logprobs` needs a model evaluation.
    A full queue is evaluated right away; otherwise evaluation is deferred
    until `timeout` seconds pass without another query arriving.

    Args:
        query (list[int]): Token IDs representing the query prompt
        future (asyncio.Future): Future to store the result in
        past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
            or None if this is a new query
    """
    self.queries.append(Query(query, future, past))

    pending = self.timer
    if pending:
        pending.cancel()
        self.timer = None

    if len(self.queries) < self.batch_size:
        self.timer = asyncio.get_running_loop().call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )
    else:
        self.batch_evaluate_queries()

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

Name Type Description Default
token_ids list[int]

Sequence of token IDs to follow in the cache tree

required

Returns:

Name Type Description
tuple
  • CacheNode: The deepest node in the cache tree that matches the token sequence
  • int: Number of tokens matched from the start of token_ids
  • list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node, or None if no cached states were found
  • int: Base index indicating where the past states start in token_ids
Source code in genlm/backend/llm/hf.py
def walk_cache(self, token_ids):
    """Follow `token_ids` down the cache tree as far as it is populated.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple:
            - CacheNode: The deepest node in the cache tree that matches the token sequence
            - int: Number of tokens matched from the start of token_ids
            - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                or None if no cached states were found
            - int: Base index indicating where the past states start in token_ids
    """
    current = self.cache
    depth = 0
    last_past, last_base = None, 0
    total = len(token_ids)

    while depth < total:
        # Remember the deepest node visited so far that carries KV states; it is
        # checked before stepping, so the node reached on the final step is not.
        if current.past_key_values is not None:
            last_past, last_base = current.past_key_values, depth
        tok = token_ids[depth]
        if not current.has_token(tok):
            break
        current = current.get_token(tok)
        depth += 1

    return current, depth, last_past, last_base

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Fully cached prompts are answered immediately. Otherwise the portion of the
    prompt not covered by cached key/value states is submitted for batched
    evaluation, and the resulting logits are used to extend the cache.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    deepest, matched, cached_past, past_base = self.walk_cache(token_ids)
    if matched == len(token_ids):
        # Whole prompt already cached: no model call required.
        return deepest.logprobs

    loop = asyncio.get_running_loop()
    pending = loop.create_future()
    # Only tokens not covered by `cached_past` are submitted for evaluation.
    self.add_query(token_ids[past_base:], pending, cached_past)
    logits = await pending

    return deepest.extend_cache(matched, token_ids, logits, past_base).logprobs

next_token_logprobs_sync(token_ids)

Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

    The longest cached prefix of `token_ids` is located first. If the whole
    prompt is cached, its stored log probabilities are returned directly.
    Otherwise the tokens after the deepest cached key/value states are run
    through the model, and the cache is extended with the resulting logits.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    # Walk while tokens can be found
    node, next_token_index, past, base = self.walk_cache(token_ids)

    if next_token_index == len(token_ids):
        return node.logprobs

    # `past` holds the states for token_ids[:base] and may live on an ancestor
    # of `node` (walk_cache reports the deepest node that *has* KV states).
    # Because only token_ids[base:] is fed to the model, exactly `past` must be
    # supplied. Using node.past_key_values instead would drop the prefix
    # context whenever `node` itself has no cached states but base > 0.
    # This mirrors the async path, which submits the same (suffix, past) pair.
    logits = self.model(
        torch.tensor([token_ids[base:]]).to(self.device),
        past_key_values=past,
        use_cache=past is not None,
    ).logits[0]

    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_uncached(token_ids)

Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
def next_token_logprobs_uncached(self, token_ids):
    """Run a fresh forward pass over the full prompt and return next-token log probabilities.

    No KV or output caching is used, and auto-batching is not supported.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    with torch.no_grad():
        batch = torch.tensor([token_ids]).to(self.device)
        output = self.model(batch, past_key_values=None, use_cache=False)
        # First (only) sequence in the batch, last position: the next-token logits.
        final_logits = output.logits[0][-1]
        return final_logits.log_softmax(dim=0)

load_model_by_name(name, backend=None, llm_opts=None)

Load a language model by name.

Parameters:

Name Type Description Default
name str

Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct")

required
backend str

Backend to use for inference. Can be "vllm", "hf", "mlx", "sgl", or "mock". If None, defaults to "vllm" if CUDA is available, otherwise "hf".

None
llm_opts dict

Additional options to pass to the backend constructor. See AsyncVirtualLM and AsyncTransformer documentation for details.

None

Returns:

Type Description
AsyncLM

An asynchronous language model.

Raises:

Type Description
ValueError

If an invalid backend is specified.

Source code in genlm/backend/llm/__init__.py
def load_model_by_name(name, backend=None, llm_opts=None):
    """Instantiate an asynchronous language model for `name` on the requested backend.

    Args:
        name (str): Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct")
        backend (str, optional): Backend to use for inference. Can be "vllm", "hf", "mlx", "sgl", or "mock".
            If None, defaults to "vllm" if CUDA is available, otherwise "hf".
        llm_opts (dict, optional): Additional options to pass to the backend constructor.
            See AsyncVirtualLM and AsyncTransformer documentation for details.

    Returns:
        (AsyncLM): An asynchronous language model.

    Raises:
        (ValueError): If an invalid backend is specified.
    """
    chosen = backend
    if chosen is None:
        # Prefer vLLM when a CUDA device is present; otherwise use HF Transformers.
        chosen = "vllm" if torch.cuda.is_available() else "hf"

    options = llm_opts if llm_opts is not None else {}

    # (backend name, lazy class lookup) pairs. Each class name is resolved only
    # when its backend is selected, and matching uses plain equality.
    registry = (
        ("vllm", lambda: AsyncVirtualLM),
        ("hf", lambda: AsyncTransformer),
        ("mock", lambda: MockAsyncLM),
        ("mlx", lambda: AsyncMlxLM),
        ("sgl", lambda: AsyncSGLTransformer),
    )
    for key, lookup in registry:
        if chosen == key:
            return lookup().from_name(name, **options)

    raise ValueError(f"Invalid backend: {chosen}")

decode_vocab(tokenizer, byte2str_fallback='tokenizer')

Convert tokenizer vocabulary into byte and string representations.

Warning

The byte representation is the canonical form. Each element in byte_vocab is a Token object that contains both the token_id and byte_string. The string representation is provided for convenience but may not decode properly for all tokens, especially those containing invalid UTF-8 sequences.

Parameters:

Name Type Description Default
tokenizer

A Hugging Face tokenizer instance

required
byte2str_fallback str

Strategy for converting invalid UTF-8 bytes to strings. Options:

  • 'tokenizer': Use tokenizer's convert_ids_to_tokens (default)
  • 'latin1': Decode using latin1 encoding
  • 'replace': Use Unicode replacement character '�'
'tokenizer'

Returns:

Type Description
tuple

(byte_vocab, str_vocab) where byte_vocab is a list of Token objects and str_vocab is a list of strings

Source code in genlm/backend/tokenization/vocab.py
def decode_vocab(tokenizer, byte2str_fallback="tokenizer"):
    """Convert tokenizer vocabulary into byte and string representations.

    Warning:
        The byte representation is the canonical form. Each element in byte_vocab is a Token object that
        contains both the token_id and byte_string. The string representation is provided for convenience
        but may not decode properly for all tokens, especially those containing invalid UTF-8 sequences.

    If a fast tokenizer is passed, a slow tokenizer is first reloaded from
    `tokenizer.name_or_path`. A fast tokenizer is reloaded again only if the
    slow one's vocabulary cannot be decoded.

    Args:
        tokenizer: A Hugging Face tokenizer instance
        byte2str_fallback (str): Strategy for converting invalid UTF-8 bytes to strings. Options:\n
            - 'tokenizer': Use tokenizer's `convert_ids_to_tokens` (default)
            - 'latin1': Decode using latin1 encoding
            - 'replace': Use Unicode replacement character '�'

    Returns:
        (tuple): (byte_vocab, str_vocab) where byte_vocab is a list of Token objects
            and str_vocab is a list of strings

    Raises:
        ValueError: If `byte2str_fallback` is not a supported strategy, or if
            neither the slow nor the fast tokenizer yields a byte vocabulary.
    """
    if byte2str_fallback not in ("latin1", "tokenizer", "replace"):
        raise ValueError(f"Unknown byte2str_fallback strategy: {byte2str_fallback}")

    # The slow tokenizer is tried first, so swap a fast tokenizer for the slow
    # one loaded from the same checkpoint.
    if tokenizer.is_fast:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer.name_or_path, use_fast=False
        )

    try:
        raw_byte_vocab = get_byte_vocab(tokenizer)
    except ByteVocabError:
        # The slow tokenizer could not be decoded; retry with the fast tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer.name_or_path, use_fast=True)
        try:
            raw_byte_vocab = get_byte_vocab(tokenizer)
        except ByteVocabError as e:
            raise ValueError(
                "Could not decode byte representation of token vocabulary "
                f"from tokenizer {tokenizer.name_or_path}"
            ) from e

    # Create Token objects for byte_vocab.
    # Assumption: token_id == position index in the vocabulary. This is relied upon
    # by the trie (idx_to_leaf) and weight arrays (ws[i] corresponds to decode[i]).
    byte_vocab = [
        Token(token_id=i, byte_string=b) for i, b in enumerate(raw_byte_vocab)
    ]
    str_vocab = bytes_to_strs(tokenizer, raw_byte_vocab, byte2str_fallback)

    return byte_vocab, str_vocab

Token

Bases: bytes

A vocabulary token carrying both a token ID and its byte representation.

Subclasses bytes so that existing code using byte operations (b"".join, len, indexing, .decode()) continues to work. Equality and hashing between Token objects use token_id, not byte content.

Parameters:

Name Type Description Default
token_id int

The unique identifier for this token in the vocabulary.

required
byte_string bytes

The byte representation of this token.

required
Source code in genlm/backend/tokenization/token.py
class Token(bytes):
    """Vocabulary entry that is both a `bytes` value and a token id.

    Because the class derives from ``bytes``, byte-level operations such as
    ``b"".join``, ``len``, indexing and ``.decode()`` keep working unchanged.
    When both operands are Tokens, equality, ordering and hashing are decided
    by ``token_id`` alone; the byte content is ignored.

    NOTE: against a plain ``bytes`` operand the Token methods return
    ``NotImplemented``. Python then uses the reflected ``bytes`` comparison,
    which compares content. Hashing, however, always uses ``token_id``, so
    mixing Tokens and equal plain bytes in one set/dict is not reliable.

    Args:
        token_id (int): The unique identifier for this token in the vocabulary.
        byte_string (bytes): The byte representation of this token.
    """

    def __new__(cls, token_id: int, byte_string: bytes):
        if not isinstance(token_id, int):
            raise TypeError(f"token_id must be an int, got {type(token_id)}")
        if not isinstance(byte_string, bytes):
            raise TypeError(f"byte_string must be bytes, got {type(byte_string)}")
        token = super().__new__(cls, byte_string)
        token.token_id = token_id
        return token

    @property
    def byte_string(self):
        """The byte representation of this token (as plain bytes)."""
        return bytes(self)

    def __repr__(self):
        content = bytes(self)
        return f"Token(token_id={self.token_id}, byte_string={content!r})"

    # ---- equality / hashing (token_id-based between Tokens) ----

    def __eq__(self, other):
        if not isinstance(other, Token):
            return NotImplemented
        return self.token_id == other.token_id

    def __ne__(self, other):
        if not isinstance(other, Token):
            return NotImplemented
        return self.token_id != other.token_id

    def __hash__(self):
        return hash(self.token_id)

    # ---- ordering (token_id-based between Tokens) ----

    def __lt__(self, other):
        if isinstance(other, Token):
            return self.token_id < other.token_id
        return NotImplemented

    def __le__(self, other):
        if isinstance(other, Token):
            return self.token_id <= other.token_id
        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, Token):
            return self.token_id > other.token_id
        return NotImplemented

    def __ge__(self, other):
        if isinstance(other, Token):
            return self.token_id >= other.token_id
        return NotImplemented

    # ---- helpers ----

    @staticmethod
    def as_bytes(x):
        """Extract byte string from a Token or pass through plain bytes."""
        if isinstance(x, Token):
            return x.byte_string
        return x

    @staticmethod
    def is_plain_bytes(x):
        """Check if x is plain bytes (not a Token)."""
        if isinstance(x, Token):
            return False
        return isinstance(x, bytes)

    # ---- pickle / copy support: rebuild from (token_id, raw bytes) ----

    def __reduce__(self):
        return (Token, (self.token_id, bytes(self)))

byte_string property

The byte representation of this token (as plain bytes).

as_bytes(x) staticmethod

Extract byte string from a Token or pass through plain bytes.

Source code in genlm/backend/tokenization/token.py
@staticmethod
def as_bytes(x):
    """Return plain bytes for `x`: a Token's `byte_string`, or `x` itself otherwise."""
    if isinstance(x, Token):
        return x.byte_string
    return x

is_plain_bytes(x) staticmethod

Check if x is plain bytes (not a Token).

Source code in genlm/backend/tokenization/token.py
@staticmethod
def is_plain_bytes(x):
    """Return True only for `bytes` instances that are not Token objects."""
    if isinstance(x, Token):
        return False
    return isinstance(x, bytes)

TokenCharacterTrie

A trie data structure for efficient token-to-character mapping.

Source code in genlm/backend/trie/base.py
class TokenCharacterTrie:
    """A trie data structure for efficient token-to-character mapping.

    Every vocabulary item is spelled out symbol by symbol from the root.
    The node reached after its last symbol gets a dedicated leaf child,
    attached through an edge keyed ``(None, idx)``, where ``idx`` is the
    item's position in ``decode``.

    After construction, nodes are renumbered in post-order (children before
    parents). A parent therefore always has a larger id than its children,
    and the root has the largest id.
    """

    def __init__(self, decode):
        """Initialize a `TokenCharacterTrie`.

        Args:
            decode (list): List representing the token vocabulary.
                Each element must be iterable. Token objects use byte_string for iteration,
                other iterables (bytes, EndOfSequence) are iterated directly.

        Raises:
            ValueError: If two items of `decode` produce the same key: the same
                `(byte_string, token_id)` for Tokens, or equal values otherwise.
        """
        self.decode = decode
        self.word2leaf = {}
        self.children = [{}]  # First node is root
        self.root = 0
        # Maps position index in decode to leaf node: idx_to_leaf[k] = (idx, leaf_node)
        self.idx_to_leaf = []

        _warned_bytes = False
        for idx, item in enumerate(self.decode):
            # Get the word (bytes to iterate) and a unique key for word2leaf
            if isinstance(item, Token):
                word = item.byte_string
                word_key = (item.byte_string, item.token_id)
            elif Token.is_plain_bytes(item):
                if not _warned_bytes:
                    warnings.warn(
                        "Passing plain bytes to TokenCharacterTrie is deprecated. "
                        "Use Token objects from decode_vocab() instead.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    _warned_bytes = True
                word = item
                word_key = item
            else:
                # For other iterables (e.g. EndOfSequence), iterate directly
                word = item
                word_key = item

            curr = self.root
            for letter in word:
                if letter not in self.children[curr]:
                    self.children[curr][letter] = len(self.children)
                    self.children.append({})
                curr = self.children[curr][letter]

            # Each item gets its own leaf
            leaf_edge_key = (None, idx)  # Use position index for uniqueness

            self.children[curr][leaf_edge_key] = last = len(self.children)
            self.children.append({})

            if word_key in self.word2leaf:
                raise ValueError(f"Duplicate word in vocabulary: {word_key}")
            self.word2leaf[word_key] = last

            # Use position index for weight array indexing
            self.idx_to_leaf.append((idx, last))

        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
        )
        # Post-order over internal nodes only (_order skips leaf edges).
        self.ordering = np.array(list(self._order(self.root)), np.int32)

        # Renumber the states of the trie so that they are named by a contiguous
        # range of integers that respects the topological ordering of the trie.
        # This improves the efficiency of updating the trie because it improves
        # memory locality.
        ordering = {}
        for i, x in enumerate(self._order_full(self.root)):
            ordering[x] = i
        self._rename(f=lambda x: ordering[x])

        # After renumbering, parents have larger ids than their children, so a
        # descending sweep sees every parent's prefix before its children's.
        node2prefix = {self.root: []}
        for x in reversed(range(len(self.children))):
            for letter, y in self.children[x].items():
                # Leaf edges are tuples of (None, idx), idx = position in decode
                if isinstance(letter, tuple) and letter[0] is None:
                    node2prefix[y] = node2prefix[x]
                else:
                    node2prefix[y] = node2prefix[x] + [letter]
        self.node2prefix = node2prefix

    def _rename(self, f):
        """Rename all node indices in the trie using the provided mapping function.

        Args:
            f (callable): Function that maps old node indices to new node indices
        """
        N = len(self.children)

        new_children = [{} for _ in range(N)]
        nodes = range(N)

        for x in nodes:
            for letter, y in self.children[x].items():
                new_children[f(x)][letter] = f(y)

        self.root = f(self.root)
        self.children = new_children
        self.word2leaf = {w: f(x) for w, x in self.word2leaf.items()}
        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))

        self.idx_to_leaf = np.array(
            [(i, f(x)) for i, x in self.idx_to_leaf], dtype=np.int32
        )

        self.ordering = np.array([f(x) for x in self.ordering])
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in new_children]
        )

    def _alloc_weights(self):
        """Allocate an array to store weight values for all nodes.

        Returns:
            np.ndarray: Zero-initialized array for storing weight values
        """
        return np.zeros(len(self.children), dtype=np.float64)

    def _preprocess_ws(self, ws):
        """Preprocess the weight vector to ensure it is a numpy array on the CPU.

        Torch tensors are detached from the autograd graph, moved to the CPU
        if needed and converted to numpy. Other inputs are returned unchanged.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Weight vector
        """
        if isinstance(ws, torch.Tensor):
            # Tensor.numpy() refuses tensors that require grad; detach() first
            # so weights coming straight out of a differentiable computation work.
            ws = ws.detach()
            if ws.device.type != "cpu":
                ws = ws.cpu()
            ws = ws.numpy()
        return ws

    def weight_sum(self, ws):
        """Compute weight sum for each node in the trie.

        For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
        that are descendants of that node.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights indexed by position in `self.decode`,
                i.e., `ws[i]` is the weight for `self.decode[i]`. Shape: `(len(self.decode),)`

        Returns:
            (np.ndarray): Summed weights for each node in the trie.
        """
        ws = self._preprocess_ws(ws)
        node_ws = self._alloc_weights()
        _update_trie_numba_sum(
            node_ws=node_ws,
            ws=ws,
            idx_to_leaf=self.idx_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return node_ws

    def weight_max(self, ws):
        """Compute weight max for each node in the trie.

        For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
        that are descendants of that node.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights indexed by position in `self.decode`,
                i.e., `ws[i]` is the weight for `self.decode[i]`. Shape: `(len(self.decode),)`

        Returns:
            (np.ndarray): Weight max values for each node in the trie.
        """
        ws = self._preprocess_ws(ws)
        node_ws = self._alloc_weights()
        _update_trie_numba_max(
            node_ws=node_ws,
            ws=ws,
            idx_to_leaf=self.idx_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return node_ws

    def batch_weight_sum(self, ws):
        """Batched equivalent of `weight_sum`.

        Args:
            ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each indexed by
                position in `self.decode`. Shape of each: `(len(self.decode),)`

        Returns:
            (np.ndarray): One row of per-node summed weights for each entry of `ws`.
        """
        return np.array([self.weight_sum(w) for w in ws])

    def batch_weight_max(self, ws):
        """Batched equivalent of `weight_max`.

        Args:
            ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each indexed by
                position in `self.decode`. Shape of each: `(len(self.decode),)`

        Returns:
            (np.ndarray): One row of per-node max weights for each entry of `ws`.
        """
        return np.array([self.weight_max(w) for w in ws])

    def _order(self, node):
        """Generate a topological ordering of nodes beneath the given node.

        Leaf nodes (reached through `(None, idx)` edges) are not yielded.

        Args:
            node (int): Starting node index

        Yields:
            int: Node indices in topological order
        """
        for a in self.children[node]:
            # Skip leaf edges (tuples of (None, idx))
            if isinstance(a, tuple) and a[0] is None:
                pass
            else:
                yield from self._order(self.children[node][a])
        yield node

    def _order_full(self, node):
        """Generate a complete topological ordering including all child nodes.

        Args:
            node (int): Starting node index

        Yields:
            (int): Node indices in complete topological order
        """
        for a in self.children[node]:
            yield from self._order_full(self.children[node][a])
        yield node

    def visualize(self, ws=None):
        """Visualize the trie structure using Graphviz.

        Args:
            ws (np.ndarray|None): Optional weight vector to display at each node.
                                Should be of length `len(self.children)`.

        Returns:
            (graphviz.Digraph): The generated graph object

        Raises:
            ValueError: If `ws` is given and its length differs from the node count.
        """
        try:
            import graphviz
        except ImportError:  # pragma: no cover
            raise ImportError(
                "Please install graphviz: pip install graphviz"
            )  # pragma: no cover

        if ws is not None and len(ws) != len(self.children):
            raise ValueError(
                f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
            )

        dot = graphviz.Digraph(comment="Token Character Trie")
        dot.attr(rankdir="LR")

        # Create a subgraph for the legend
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", fontsize="10")
            legend.attr("node", fontsize="7", width="0.1", height="0.1")

            # Example internal node
            legend.node(
                "legend_internal",
                "Internal Node ID\n'Prefix'\nWeight (if provided)",
                shape="circle",
            )

            # Example leaf node
            legend.node("legend_leaf", "Complete Token", shape="doublecircle")

            legend.edge(
                "legend_internal",
                "legend_leaf",
                label="Token item",
                fontsize="10",
            )

            # Align legend horizontally
            legend.attr(rankdir="TB")
            legend.attr(rank="same")

        # The colour scale depends only on the global maximum: compute it once
        # rather than once per node (which made this loop quadratic).
        max_ws = ws.max() if ws is not None else None

        # Add the main trie nodes and edges
        for node_id in range(len(self.children)):
            prefix = self.node2prefix[node_id]

            if ws is not None:
                label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
            else:
                label = f"{node_id}\n'{prefix}'"

            # Color nodes based on mass if provided
            if ws is not None:
                if max_ws > 0:
                    intensity = int(255 * (1 - ws[node_id] / max_ws))
                    color = f"#{intensity:02x}{255:02x}{intensity:02x}"
                else:
                    color = "#ffffff"  # white for zero mass
            else:
                color = "#ffffff"  # default white

            if node_id in self.leaf2word:
                dot.node(
                    str(node_id),
                    label,
                    shape="doublecircle",
                    style="filled",
                    fillcolor=color,
                )
            else:
                dot.node(
                    str(node_id), label, shape="circle", style="filled", fillcolor=color
                )

        for node_id, children in enumerate(self.children):
            for char, child_id in children.items():
                # Leaf edges are tuples of (None, idx), idx = position in decode
                if isinstance(char, tuple) and char[0] is None:
                    edge_label = f"End-of-Token (ID: {char[1]})"
                else:
                    edge_label = str(char)

                dot.edge(str(node_id), str(child_id), label=edge_label)

        return dot

__init__(decode)

Initialize a TokenCharacterTrie.

Parameters:

Name Type Description Default
decode list

List representing the token vocabulary. Each element must be iterable. Token objects use byte_string for iteration, other iterables (bytes, EndOfSequence) are iterated directly.

required
Source code in genlm/backend/trie/base.py
def __init__(self, decode):
    """Initialize a `TokenCharacterTrie`.

    Each item of `decode` is spelled out symbol by symbol from the root. The
    node reached after its last symbol gets a dedicated leaf child, attached
    via an edge keyed `(None, idx)`, where `idx` is the item's position in
    `decode`. The nodes are then renumbered in post-order and a prefix table
    (`node2prefix`) is built.

    Args:
        decode (list): List representing the token vocabulary.
            Each element must be iterable. Token objects use byte_string for iteration,
            other iterables (bytes, EndOfSequence) are iterated directly.

    Raises:
        ValueError: If two items produce the same key: the same
            `(byte_string, token_id)` for Tokens, or equal values otherwise.
    """
    self.decode = decode
    self.word2leaf = {}
    self.children = [{}]  # First node is root
    self.root = 0
    # Maps position index in decode to leaf node: idx_to_leaf[k] = (idx, leaf_node)
    self.idx_to_leaf = []

    # Emit the plain-bytes DeprecationWarning at most once per trie.
    _warned_bytes = False
    for idx, item in enumerate(self.decode):
        # Get the word (bytes to iterate) and a unique key for word2leaf
        if isinstance(item, Token):
            word = item.byte_string
            # Include token_id so distinct tokens with identical bytes get distinct keys.
            word_key = (item.byte_string, item.token_id)
        elif Token.is_plain_bytes(item):
            if not _warned_bytes:
                warnings.warn(
                    "Passing plain bytes to TokenCharacterTrie is deprecated. "
                    "Use Token objects from decode_vocab() instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                _warned_bytes = True
            word = item
            word_key = item
        else:
            # For other iterables (e.g. EndOfSequence), iterate directly
            word = item
            word_key = item

        # Descend from the root, creating a child node for each unseen symbol.
        curr = self.root
        for letter in word:
            if letter not in self.children[curr]:
                self.children[curr][letter] = len(self.children)
                self.children.append({})
            curr = self.children[curr][letter]

        # Each item gets its own leaf
        leaf_edge_key = (None, idx)  # Use position index for uniqueness

        self.children[curr][leaf_edge_key] = last = len(self.children)
        self.children.append({})

        if word_key in self.word2leaf:
            raise ValueError(f"Duplicate word in vocabulary: {word_key}")
        self.word2leaf[word_key] = last

        # Use position index for weight array indexing
        self.idx_to_leaf.append((idx, last))

    # NOTE: _rename below recomputes leaf2word and jump from the renumbered
    # nodes, so these pre-rename values are overwritten.
    self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
    self.jump = List(
        [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
    )
    # Post-order over internal nodes only (_order skips leaf edges); _rename
    # maps these ids to the new numbering.
    self.ordering = np.array(list(self._order(self.root)), np.int32)

    # Renumber the states of the trie so that they are named by a contiguous
    # range of integers that respects the topological ordering of the trie.
    # This improves the efficiency of updating the trie because it improves
    # memory locality.
    ordering = {}
    for i, x in enumerate(self._order_full(self.root)):
        ordering[x] = i
    self._rename(f=lambda x: ordering[x])

    # _order_full yields children before parents, so after renumbering every
    # parent id exceeds its children's ids. A descending sweep therefore sees
    # each parent's prefix before it is extended for its children.
    node2prefix = {self.root: []}
    for x in reversed(range(len(self.children))):
        for letter, y in self.children[x].items():
            # Leaf edges are tuples of (None, idx), idx = position in decode;
            # a leaf inherits its parent's prefix unchanged.
            if isinstance(letter, tuple) and letter[0] is None:
                node2prefix[y] = node2prefix[x]
            else:
                node2prefix[y] = node2prefix[x] + [letter]
    self.node2prefix = node2prefix

weight_sum(ws)

Compute weight sum for each node in the trie.

For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node.

Parameters:

Name Type Description Default
ws Tensor | ndarray

Token weights indexed by position in self.decode, i.e., ws[i] is the weight for self.decode[i]. Shape: (len(self.decode),)

required

Returns:

Type Description
ndarray

Summed weights for each node in the trie.

Source code in genlm/backend/trie/base.py
def weight_sum(self, ws):
    """Aggregate token weights up the trie by summation.

    The value for each node is the total weight of all leaf nodes (tokens)
    that descend from it.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights indexed by position in `self.decode`,
            i.e., `ws[i]` is the weight for `self.decode[i]`. Shape: `(len(self.decode),)`

    Returns:
        (np.ndarray): Summed weights for each node in the trie.
    """
    weights = self._preprocess_ws(ws)
    totals = self._alloc_weights()
    # Fills `totals` in place.
    _update_trie_numba_sum(
        node_ws=totals,
        ws=weights,
        idx_to_leaf=self.idx_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return totals

weight_max(ws)

Compute weight max for each node in the trie.

For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node.

Parameters:

Name Type Description Default
ws Tensor | ndarray

Token weights indexed by position in self.decode, i.e., ws[i] is the weight for self.decode[i]. Shape: (len(self.decode),)

required

Returns:

Type Description
ndarray

Weight max values for each node in the trie.

Source code in genlm/backend/trie/base.py
def weight_max(self, ws):
    """Propagate the largest token weight up the trie.

    Every trie node receives the maximum weight among the leaf nodes (tokens)
    that are its descendants.

    Args:
        ws (torch.Tensor|np.ndarray): One weight per vocabulary entry, aligned with
            `self.decode` (`ws[i]` belongs to `self.decode[i]`).
            Shape: `(len(self.decode),)`

    Returns:
        (np.ndarray): Weight max values for each node in the trie.
    """
    token_weights = self._preprocess_ws(ws)
    maxima = self._alloc_weights()
    # The kernel's return value is ignored and `maxima` is returned, so the
    # kernel presumably fills the buffer in place (TODO confirm).
    _update_trie_numba_max(
        node_ws=maxima,
        ws=token_weights,
        idx_to_leaf=self.idx_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return maxima

batch_weight_sum(ws)

Batched equivalent of weight_sum.

Parameters:

Name Type Description Default
ws list[Tensor | ndarray]

Batch of token weights, each indexed by position in self.decode. Shape of each: (len(self.decode),)

required

Returns:

Type Description
ndarray

Stacked per-node summed weights, with one row for each of the len(ws) input vectors.

Source code in genlm/backend/trie/base.py
def batch_weight_sum(self, ws):
    """Apply `weight_sum` to every weight vector in a batch.

    Args:
        ws (list[torch.Tensor|np.ndarray]): Token-weight vectors, each aligned with
            `self.decode`. Shape of each: `(len(self.decode),)`

    Returns:
        (np.ndarray): The per-vector `weight_sum` results stacked into one array,
            one row per element of `ws`.
    """
    per_vector = list(map(self.weight_sum, ws))
    return np.array(per_vector)

batch_weight_max(ws)

Batched equivalent of weight_max.

Parameters:

Name Type Description Default
ws list[Tensor | ndarray]

Batch of token weights, each indexed by position in self.decode. Shape of each: (len(self.decode),)

required

Returns:

Type Description
ndarray

Stacked per-node maximum weights, with one row for each of the len(ws) input vectors.

Source code in genlm/backend/trie/base.py
def batch_weight_max(self, ws):
    """Apply `weight_max` to every weight vector in a batch.

    Args:
        ws (list[torch.Tensor|np.ndarray]): Token-weight vectors, each aligned with
            `self.decode`. Shape of each: `(len(self.decode),)`

    Returns:
        (np.ndarray): The per-vector `weight_max` results stacked into one array,
            one row per element of `ws`.
    """
    per_vector = list(map(self.weight_max, ws))
    return np.array(per_vector)

visualize(ws=None)

Visualize the trie structure using Graphviz.

Parameters:

Name Type Description Default
ws ndarray | None

Optional weight vector to display at each node. Should be of length len(self.children).

None

Returns:

Type Description
Digraph

The generated graph object

Source code in genlm/backend/trie/base.py
def visualize(self, ws=None):
    """Visualize the trie structure using Graphviz.

    Each trie node is drawn labelled with its node id and its prefix (plus its
    weight when `ws` is given). Nodes found in `self.leaf2word` are drawn as
    double circles and all other nodes as circles. When weights are supplied,
    the fill colour shades from white towards pure green as a node's weight
    approaches the maximum weight.

    Args:
        ws (np.ndarray|None): Optional weight vector to display at each node.
                            Should be of length `len(self.children)`.

    Returns:
        (graphviz.Digraph): The generated graph object

    Raises:
        ImportError: If the `graphviz` package is not installed.
        ValueError: If `ws` is given and its length differs from `len(self.children)`.
    """
    try:
        import graphviz
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Please install graphviz: pip install graphviz"
        )  # pragma: no cover

    if ws is not None and len(ws) != len(self.children):
        raise ValueError(
            f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
        )

    dot = graphviz.Digraph(comment="Token Character Trie")
    dot.attr(rankdir="LR")

    # Create a subgraph for the legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", fontsize="10")
        legend.attr("node", fontsize="7", width="0.1", height="0.1")

        # Example internal node
        legend.node(
            "legend_internal",
            "Internal Node ID\n'Prefix'\nWeight (if provided)",
            shape="circle",
        )

        # Example leaf node
        legend.node("legend_leaf", "Complete Token", shape="doublecircle")

        legend.edge(
            "legend_internal",
            "legend_leaf",
            label="Token item",
            fontsize="10",
        )

        # Align legend horizontally
        legend.attr(rankdir="TB")
        legend.attr(rank="same")

    # The colour scale depends only on the largest weight, so compute it once
    # rather than rescanning `ws` for every node. Guard the empty case, where
    # there are no nodes to colour and `.max()` would fail.
    max_ws = ws.max() if ws is not None and len(ws) > 0 else None

    # Add the main trie nodes and edges
    for node_id in range(len(self.children)):
        prefix = self.node2prefix[node_id]

        if ws is not None:
            label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
        else:
            label = f"{node_id}\n'{prefix}'"

        # Color nodes based on mass if provided. White is used when no weights
        # are given or when the maximum weight is not positive.
        color = "#ffffff"
        if ws is not None and max_ws > 0:
            intensity = int(255 * (1 - ws[node_id] / max_ws))
            # Negative weights push the intensity above 255, which would yield a
            # 3-digit hex component. Clamp so the colour is always valid #rrggbb.
            intensity = min(max(intensity, 0), 255)
            color = f"#{intensity:02x}{255:02x}{intensity:02x}"

        if node_id in self.leaf2word:
            dot.node(
                str(node_id),
                label,
                shape="doublecircle",
                style="filled",
                fillcolor=color,
            )
        else:
            dot.node(
                str(node_id), label, shape="circle", style="filled", fillcolor=color
            )

    for node_id, children in enumerate(self.children):
        for char, child_id in children.items():
            # Leaf edges are tuples of (None, token_id)
            if isinstance(char, tuple) and char[0] is None:
                edge_label = f"End-of-Token (ID: {char[1]})"
            else:
                edge_label = str(char)

            dot.edge(str(node_id), str(child_id), label=edge_label)

    return dot

ParallelTokenCharacterTrie

Bases: TokenCharacterTrie

A GPU-optimized version of TokenCharacterTrie that performs weight sum and max operations in parallel.

Source code in genlm/backend/trie/parallel.py
class ParallelTokenCharacterTrie(TokenCharacterTrie):
    """A GPU-optimized version of `TokenCharacterTrie` that performs weight sum and max operations in parallel.

    Both operations process a whole batch of weight vectors at once. Sums are a
    single sparse matrix product with a leaf-to-ancestor reachability matrix.
    Maxima are a single `scatter_reduce_` over the same (leaf, ancestor) pairs.
    """

    def __init__(self, decode, device=None, **kwargs):
        """Build the trie and precompute the tensors used for batched queries.

        Args:
            decode (list): Vocabulary, forwarded to `TokenCharacterTrie`.
            device (str|None): 'cpu' or 'cuda'. When None, 'cuda' is used if
                available and 'cpu' otherwise.
            **kwargs: Additional arguments forwarded to `TokenCharacterTrie`.

        Raises:
            ValueError: If the resolved device is not 'cpu' or 'cuda'.
        """
        super().__init__(decode, **kwargs)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device not in ["cpu", "cuda"]:
            raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

        self._build_reachability_matrix()
        # Position in `decode` of each leaf, in the same row order as `self.M`.
        # Used to gather per-leaf weights out of a full token-weight vector.
        self.positions = torch.tensor(
            self.idx_to_leaf[:, 0], dtype=torch.long, device=self.device
        )

    def _build_parent_map(self):
        """Builds a mapping from each node to its parent node in the trie.

        Returns:
            (dict): A dictionary where keys are child nodes and values are their parent nodes.
        """
        parent = {}
        for node in range(len(self.children)):
            for child in self.jump[node]:
                parent[child] = node
        return parent

    def _build_reachability_matrix(self):
        """Constructs a sparse reachability matrix for efficient weight propagation.

        The matrix M, of shape (num_leaves, `len(self.children)`), is built so
        that M[i, j] = 1 when node j is either:
        - The leaf node of the i-th leaf entry itself (self-connection)
        - An ancestor of that leaf node in the trie

        The same (i, j) pairs are also kept as `self.src_indices` (leaf rows) and
        `self.dst_indices` (node ids) for use by `batch_weight_max`.
        """
        leaf_indices = self.idx_to_leaf[:, 1]
        parent = self._build_parent_map()

        rows, cols = [], []
        for i, node in enumerate(leaf_indices):
            # self connections
            rows.append(i)
            cols.append(node)

            current = node
            while current in parent:  # Walk up to root
                ancestor = parent[current]
                rows.append(i)
                cols.append(ancestor)
                current = ancestor

        self.src_indices = torch.tensor(rows, dtype=torch.long, device=self.device)
        self.dst_indices = torch.tensor(cols, dtype=torch.long, device=self.device)

        indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device)
        values = torch.ones(len(rows), device=self.device)

        self.M = torch.sparse_coo_tensor(
            indices, values, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

    def _preprocess_ws(self, batch_ws):
        """Convert a batch of weight vectors into one stacked float32 tensor on `self.device`.

        Args:
            batch_ws (Iterable[torch.Tensor|array-like]): Weight vectors, each of
                length `len(self.decode)`. A 2-D tensor is accepted and iterated row-wise.

        Returns:
            (torch.Tensor): Tensor of shape (batch_size, `len(self.decode)`).

        Raises:
            AssertionError: If a vector's length differs from `len(self.decode)`.
        """
        processed_batch_ws = []
        for ws in batch_ws:
            if not isinstance(ws, torch.Tensor):
                ws = torch.tensor(ws, device=self.device, dtype=torch.float32)
            elif ws.device != self.device or ws.dtype != torch.float32:
                ws = ws.to(device=self.device, dtype=torch.float32)
            assert ws.shape[0] == len(self.decode), [ws.shape[0], len(self.decode)]
            processed_batch_ws.append(ws)
        return torch.stack(processed_batch_ws)

    def weight_sum(self, ws):
        """Computes weight sums given token weights.

        For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
        that are descendants of that node. This is efficiently implemented using sparse matrix multiplication
        with a pre-computed reachability matrix.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (`len(self.children)`,).
        """
        # `batch_weight_sum` already preprocesses its input, so wrap without converting here.
        return self.batch_weight_sum([ws])[0]

    def batch_weight_sum(self, ws):
        """Batch version of `weight_sum`.

        Args:
            ws (list[torch.Tensor|np.ndarray]|torch.Tensor): Batch of token weights,
                shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)
        # (batch_size × num_leaves) @ (num_leaves × num_nodes) -> (batch_size × num_nodes)
        masses = torch.sparse.mm(ws[:, self.positions], self.M)
        return masses.cpu().numpy()

    def weight_max(self, ws):
        """Computes the max weights given the token weights.

        For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
        that are descendants of that node. This is efficiently implemented using parallel scatter_reduce
        operations on GPU.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (`len(self.children)`,).
        """
        # `batch_weight_max` already preprocesses its input, so wrap without converting here.
        return self.batch_weight_max([ws])[0]

    def batch_weight_max(self, ws):
        """Batch version of `weight_max`.

        Args:
            ws (list[torch.Tensor|np.ndarray]|torch.Tensor): Batch of token weights,
                shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
                Nodes that no leaf reaches keep the initial value 0.
        """
        ws = self._preprocess_ws(ws)

        # Get leaf weights (indexed by position in decode)
        leaf_weights = ws[:, self.positions]  # shape: (batch_size × num_leafs)
        batch_size = leaf_weights.shape[0]

        # Use scatter_reduce to propagate maximum values in parallel. The buffer
        # matches the source dtype, because scatter_reduce_ requires equal dtypes.
        result = torch.zeros(
            (batch_size, len(self.children)),
            dtype=leaf_weights.dtype,
            device=self.device,
        )
        result.scatter_reduce_(
            dim=1,
            index=self.dst_indices.expand(batch_size, -1),
            src=leaf_weights[:, self.src_indices],
            reduce="amax",
            include_self=False,
        )

        return result.cpu().numpy()

weight_sum(ws)

Computes weight sums given token weights.

For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using sparse matrix multiplication with a pre-computed reachability matrix.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.decode),).

required

Returns:

Type Description
ndarray

Summed weights for each node in the trie, shape (len(self.children),).

Source code in genlm/backend/trie/parallel.py
def weight_sum(self, ws):
    """Computes weight sums given token weights.

    For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
    that are descendants of that node. This is efficiently implemented using sparse matrix multiplication
    with a pre-computed reachability matrix.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (`len(self.children)`,).
    """
    # `batch_weight_sum` already preprocesses its input; converting here as
    # well would stack, unstack and re-stack the same vector.
    return self.batch_weight_sum([ws])[0]

batch_weight_sum(ws)

Batch version of weight_sum.

Parameters:

Name Type Description Default
ws Tensor

Batch of token weights, shape (batch_size × len(self.decode)).

required

Returns:

Type Description
ndarray

Summed weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/backend/trie/parallel.py
def batch_weight_sum(self, ws):
    """Sum token weights up the trie for a whole batch at once.

    Args:
        ws (list[torch.Tensor|np.ndarray]|torch.Tensor): Batch of token weights,
            shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
    """
    dense = self._preprocess_ws(ws)
    # Reorder token weights into leaf order: (batch_size × num_leaves).
    leaf_ws = dense[:, self.positions]
    # (batch_size × num_leaves) @ (num_leaves × num_nodes) -> (batch_size × num_nodes)
    node_sums = torch.sparse.mm(leaf_ws, self.M)
    return node_sums.cpu().numpy()

weight_max(ws)

Computes the max weights given the token weights.

For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using parallel scatter_reduce operations on GPU.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.decode),).

required

Returns:

Type Description
ndarray

Maximum weights for each node in the trie, shape (len(self.children),).

Source code in genlm/backend/trie/parallel.py
def weight_max(self, ws):
    """Computes the max weights given the token weights.

    For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
    that are descendants of that node. This is efficiently implemented using parallel scatter_reduce
    operations on GPU.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (`len(self.children)`,).
    """
    # `batch_weight_max` already preprocesses its input; converting here as
    # well would stack, unstack and re-stack the same vector.
    return self.batch_weight_max([ws])[0]

batch_weight_max(ws)

Batch version of weight_max.

Parameters:

Name Type Description Default
ws Tensor

Batch of token weights, shape (batch_size × len(self.decode)).

required

Returns:

Type Description
ndarray

Maximum weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/backend/trie/parallel.py
def batch_weight_max(self, ws):
    """Propagate the largest token weight up the trie for a whole batch at once.

    Args:
        ws (list[torch.Tensor|np.ndarray]|torch.Tensor): Batch of token weights,
            shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
    """
    dense = self._preprocess_ws(ws)

    # Token weights reordered into leaf order: (batch_size × num_leaves).
    leaf_ws = dense[:, self.positions]
    n_rows = leaf_ws.shape[0]

    # One value per (leaf, node-on-its-path) pair, scattered into its node with
    # an "amax" reduction. include_self=False keeps the zero fill out of the max.
    gathered = leaf_ws[:, self.src_indices]
    scatter_index = self.dst_indices.expand(n_rows, -1)
    node_max = torch.zeros((n_rows, len(self.children)), device=self.device)
    node_max.scatter_reduce_(
        dim=1,
        index=scatter_index,
        src=gathered,
        reduce="amax",
        include_self=False,
    )

    return node_max.cpu().numpy()

AsyncTokenCharacterTrie

An asynchronous wrapper for TokenCharacterTrie implementations that provides automatic request batching.

Source code in genlm/backend/trie/async_impl.py
class AsyncTokenCharacterTrie:
    """An asynchronous wrapper for TokenCharacterTrie implementations that provides automatic request batching.

    Requests are pushed onto an `asyncio.Queue` and served by one background
    task. That task drains everything currently queued, groups it by operation,
    and answers each group with a single batched trie call.
    """

    def __init__(self, trie):
        """Initialize an `AsyncTokenCharacterTrie`.

        Args:
            trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
        """
        self.trie = trie
        self._queue = None
        self._task = None

    @classmethod
    def from_vocab(cls, vocab, backend="parallel", **kwargs):
        """Creates an `AsyncTokenCharacterTrie` from a vocabulary.

        Args:
            vocab (list[Token]): The vocabulary over which the trie will be defined.
            backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                    Defaults to 'parallel' which uses GPU acceleration when available.
            **kwargs: Additional arguments passed to the trie constructor

        Returns:
            (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.

        Raises:
            ValueError: If `backend` is not 'sequential' or 'parallel'.
        """
        if backend == "sequential":
            trie = TokenCharacterTrie(decode=vocab, **kwargs)
        elif backend == "parallel":
            trie = ParallelTokenCharacterTrie(decode=vocab, **kwargs)
        else:
            raise ValueError(
                f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
            )
        return cls(trie)

    async def _queue_request(self, request, op):
        """Enqueue `request` for operation `op` ('sum' or 'max') and return its future.

        Starts the background task first if it is not running.
        """
        if not self._task or self._task.done():
            self.start()

        future = asyncio.get_running_loop().create_future()
        await self._queue.put((request, future, op))
        return future

    async def weight_sum(self, ws):
        """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The calculated mass sums for the given distribution.
        """
        future = await self._queue_request(ws, "sum")
        result = await future
        return result

    async def weight_max(self, ws):
        """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The calculated max weights for the given distribution.
        """
        future = await self._queue_request(ws, "max")
        result = await future
        return result

    def start(self):
        """Start the background processing task if not already running."""
        if not self._task or self._task.done():
            self._queue = (
                asyncio.Queue()
            )  # Create a new queue so that it is bound to the current event loop
            self._task = asyncio.create_task(self._background_loop())

    def _do_weight_sums(self, batch_weights):
        """Run one batched sum over `batch_weights` on the underlying trie."""
        return self.trie.batch_weight_sum(batch_weights)

    def _do_weight_maxs(self, batch_weights):
        """Run one batched max over `batch_weights` on the underlying trie."""
        return self.trie.batch_weight_max(batch_weights)

    async def _background_loop(self):
        """Background task that processes queued weight sum and max requests.

        Continuously monitors the queue for new requests and processes them in batches
        using the underlying trie implementation. Futures already done (for example
        cancelled by their caller) are skipped instead of failing the batch.

        Raises:
            Exception: If any error occurs during processing, it is propagated to all
                      pending futures in the current batch.
        """
        while True:
            op_groups = defaultdict(list)
            try:
                request, future, op = await self._queue.get()
                op_groups[op].append((request, future))

                # Drain everything already queued so concurrent callers share one batch.
                while not self._queue.empty():
                    request, future, op = await self._queue.get()
                    op_groups[op].append((request, future))

                for op, group in op_groups.items():
                    requests, futures = zip(*group)

                    if op == "sum":
                        logger.debug("processing %d sum requests", len(requests))
                        results = self._do_weight_sums(requests)
                    elif op == "max":
                        logger.debug("processing %d max requests", len(requests))
                        results = self._do_weight_maxs(requests)
                    else:
                        raise ValueError(f"Unknown operation: {op}")

                    for future, result in zip(futures, results):
                        # A caller may have been cancelled while waiting. Setting a
                        # result on its done future would raise InvalidStateError,
                        # failing the whole batch and killing this task.
                        if not future.done():
                            future.set_result(result)

            except Exception as e:
                for group in op_groups.values():
                    for _, future in group:
                        if not future.done():
                            future.set_exception(e)
                raise

    async def cleanup(self):
        """Cancel the background task and wait for it to finish (preferred over `shutdown`)."""
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    def shutdown(self):
        """Stop the background processing task and cleanup resources."""
        if self._task is not None:
            try:
                self._task.cancel()
            except RuntimeError:
                # Ignore runtime errors that might occur if event loop is closed
                pass
            self._task = None

    def __del__(self):
        self.shutdown()

__init__(trie)

Initialize an AsyncTokenCharacterTrie.

Parameters:

Name Type Description Default
trie TokenCharacterTrie | ParallelTokenCharacterTrie

The underlying TokenCharacterTrie or ParallelTokenCharacterTrie instance

required
Source code in genlm/backend/trie/async_impl.py
def __init__(self, trie):
    """Wrap an existing trie; the request queue and worker task are created lazily.

    Args:
        trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The trie whose batched
            `batch_weight_sum` / `batch_weight_max` methods will serve requests.
    """
    self._queue = self._task = None
    self.trie = trie

from_vocab(vocab, backend='parallel', **kwargs) classmethod

Creates an AsyncTokenCharacterTrie from a vocabulary.

Parameters:

Name Type Description Default
vocab list[Token]

The vocabulary over which the trie will be defined.

required
backend str

The trie implementation to use - either 'sequential' or 'parallel'. Defaults to 'parallel' which uses GPU acceleration when available.

'parallel'
**kwargs

Additional arguments passed to the trie constructor

{}

Returns:

Type Description
AsyncTokenCharacterTrie

The initialized asynchronous trie instance.

Source code in genlm/backend/trie/async_impl.py
@classmethod
def from_vocab(cls, vocab, backend="parallel", **kwargs):
    """Creates an `AsyncTokenCharacterTrie` from a vocabulary.

    Args:
        vocab (list[Token]): The vocabulary over which the trie will be defined.
        backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                Defaults to 'parallel' which uses GPU acceleration when available.
        **kwargs: Additional arguments passed to the trie constructor

    Returns:
        (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.
    """
    if backend == "sequential":
        trie = TokenCharacterTrie(decode=vocab, **kwargs)
    elif backend == "parallel":
        trie = ParallelTokenCharacterTrie(decode=vocab, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
        )
    return cls(trie)

weight_sum(ws) async

Queue a weight_sum request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.trie.decode),).

required

Returns:

Type Description
ndarray

The calculated mass sums for the given distribution.

Source code in genlm/backend/trie/async_impl.py
async def weight_sum(self, ws):
    """Request per-node weight sums; concurrent calls are batched together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated mass sums for the given distribution.
    """
    pending = await self._queue_request(ws, "sum")
    return await pending

weight_max(ws) async

Queue a weight_max request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.trie.decode),).

required

Returns:

Type Description
ndarray

The calculated max weights for the given distribution.

Source code in genlm/backend/trie/async_impl.py
async def weight_max(self, ws):
    """Request per-node weight maxima; concurrent calls are batched together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated max weights for the given distribution.
    """
    pending = await self._queue_request(ws, "max")
    return await pending

start()

Start the background processing task if not already running.

Source code in genlm/backend/trie/async_impl.py
def start(self):
    """Launch the background worker unless one is already running."""
    if self._task and not self._task.done():
        return
    # A fresh queue is created with each new task so that it belongs to the
    # event loop that is current now.
    self._queue = asyncio.Queue()
    self._task = asyncio.create_task(self._background_loop())

cleanup() async

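Cancel the background processing task and wait for it to finish. This is the preferred way to stop the wrapper from within a running event loop.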

Source code in genlm/backend/trie/async_impl.py
async def cleanup(self):
    """Cancel the background task and wait until it has stopped.

    Does nothing when there is no task or the task has already finished.
    """
    task = self._task
    if not task or task.done():
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    self._task = None

shutdown()

Stop the background processing task and cleanup resources.

Source code in genlm/backend/trie/async_impl.py
def shutdown(self):
    """Synchronously request cancellation of the background task and forget it."""
    task = self._task
    if task is None:
        return
    try:
        task.cancel()
    except RuntimeError:
        # The event loop may already be closed; there is nothing left to cancel.
        pass
    self._task = None