Skip to content

llm

AsyncVirtualLM

Bases: AsyncLM

Async language model using vLLM v1 with global logits processor.

This implementation uses vLLM v1's in-process mode with a global logits processor to efficiently capture full vocabulary log probabilities.

Source code in genlm/backend/llm/vllm.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
class AsyncVirtualLM(AsyncLM):  # pragma: no cover
    """Async language model using vLLM v1 with global logits processor.

    This implementation uses vLLM v1's in-process mode with a global
    logits processor to efficiently capture full vocabulary log probabilities.

    Concurrent ``next_token_logprobs`` and ``sample`` calls are queued and
    dispatched in batches; each queued request is represented by an
    ``asyncio.Future`` that is resolved once the batch has been evaluated.
    """

    # Sampling parameters used for every logprobs request: a single generated
    # token per prompt, no detokenization, and EOS is not treated as a stop.
    default_params = {
        "max_tokens": 1,
        "n": 1,
        "detokenize": False,
        "stop": None,
        "ignore_eos": True,
    }

    def __init__(
        self,
        llm_engine,
        logprobs_capture,
        cache_size=0,
        cache_opts=None,
        batch_size=20,
        timeout=0.02,
    ):
        """Initialize an `AsyncVirtualLM` instance.

        Args:
            llm_engine (LLM): The vLLM engine instance.
            logprobs_capture (GlobalLogprobsCapture): The global logprobs capture processor.
            cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
            cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to None (no extra options).
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching. Defaults to 20.
            timeout (float, optional): Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when ``batch_size`` is reached. Defaults to 0.02.

        Note:
            The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
            ``batch_next_token_logprobs_sync`` bypasses this cache and always re-evaluates; the other three logprobs methods consult it.
        """
        self.llm_engine = llm_engine
        self.logprobs_capture = logprobs_capture
        self.tokenizer = llm_engine.get_tokenizer()
        self.cache = (
            OutputCache(maxsize=cache_size, **(cache_opts or {}))
            if cache_size > 0
            else None
        )
        # No LoRA adapter is active until ``set_lora`` is called.
        self.lora_request = None
        self.lora_name_to_ids = {}

        # Auto-batching state for ``next_token_logprobs``.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        # Auto-batching state for ``sample`` (shares batch_size and timeout).
        self.sample_queries = []
        self.sample_timer = None

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, engine_opts=None, **kwargs):
        """Create an `AsyncVirtualLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load.
            engine_opts (dict): Additional options to pass to the `LLM` engine.
                These override the defaults set here (prefix caching enabled,
                log stats disabled, ``gpu_memory_utilization`` of 0.9).
            **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

        Returns:
            (AsyncVirtualLM): An `AsyncVirtualLM` instance.

        Raises:
            ImportError: If vLLM is not available.
        """
        if not HAS_VLLM:
            raise ImportError(  # pragma: no cover
                "vLLM not available. Install vLLM or use AsyncTransformer instead."
            )

        engine_opts = {
            "enable_prefix_caching": True,
            "disable_log_stats": True,
            "gpu_memory_utilization": 0.9,
            **(engine_opts or {}),
        }

        llm = LLM(model=model_name, tokenizer=model_name, **engine_opts)

        # Register the capture processor on the model runner so that it sees
        # the logits of every request the engine processes.
        logprobs_capture = GlobalLogprobsCapture()
        model_runner = cls._get_model_runner(llm)
        model_runner.input_batch.logitsprocs.argmax_invariant.append(
            logprobs_capture
        )

        return cls(llm, logprobs_capture, **kwargs)

    @staticmethod
    def _get_model_runner(llm):
        """Walk the vLLM v1 internals to reach the driver worker's model runner.

        This path is brittle against vLLM refactors, so it lives in one
        place and is reused by ``from_name`` (to inject the logits
        processor) and ``underlying_model``.
        """
        engine_core = llm.llm_engine.engine_core.engine_core
        return engine_core.model_executor.driver_worker.worker.model_runner

    @staticmethod
    def _resolve_future(future, result):
        """Set ``result`` on ``future`` unless it is already done.

        A queued future may have been cancelled (e.g. because the awaiting
        task was cancelled); calling ``set_result`` on it would raise
        ``asyncio.InvalidStateError``.
        """
        if not future.done():
            future.set_result(result)

    @staticmethod
    def _fail_future(future, exc):
        """Set ``exc`` on ``future`` unless it is already done."""
        if not future.done():
            future.set_exception(exc)

    @property
    def underlying_model(self):
        """Access the underlying model for advanced use cases."""
        return self._get_model_runner(self.llm_engine).model

    def clear_lora(self):
        """
        Disable any active LoRA adapter for the vLLM engine.
        """
        self.lora_request = None

    def add_new_lora(self, lora_path, lora_name="lora_1"):
        """Register a LoRA adapter name by assigning it a deterministic id.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
                Not used by this method; the path is supplied to `set_lora()`.
            lora_name (str): Name to assign to the adapter.

        Notes:
            This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
        """
        self.lora_name_to_ids[lora_name] = self.hash_to_int(lora_name)

    def hash_to_int(self, value):
        """Generates a deterministic id for a LoRA adapter from its name.

        The id is derived from a 4-byte SHAKE-128 digest of the name, so
        distinct names can in principle collide.

        Args:
            value (str): The name of the LoRA adapter to hash.

        Returns:
            An integer ID corresponding to the LoRA adapter, in the range [1, 2^31 - 2].
        """
        hash_bytes = hashlib.shake_128(value.encode("utf-8")).digest(4)
        return (int.from_bytes(hash_bytes, "big") % (2**31 - 2)) + 1

    def set_lora(self, lora_path, lora_name="lora_1"):
        """Configure a LoRA adapter request for the vLLM engine.

        The adapter id is looked up from the name registered by
        `add_new_lora()`.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.

        Raises:
            ValueError: If ``lora_name`` has not been registered with `add_new_lora()`.
        """
        if lora_name not in self.lora_name_to_ids:
            raise ValueError(
                f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
            )
        self.lora_request = LoRARequest(
            lora_name, self.lora_name_to_ids[lora_name], lora_path
        )

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously with auto-batching.

        Concurrent calls to this method are automatically batched into a single
        ``LLM.generate()`` call for efficiency. Use with ``await``.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            result (torch.Tensor): Normalized log probability tensor.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        future = asyncio.get_running_loop().create_future()
        self._add_query(token_ids, future)
        result = await future

        if self.cache is not None:
            self.cache[key] = result

        return result

    def _add_query(self, token_ids, future):
        """Add a query to be evaluated in the next batch.

        The timeout is measured from the *first* queued query, not the most
        recent one: we only arm the timer when the queue transitions from
        empty to non-empty. This prevents starvation when queries trickle in
        faster than ``self.timeout`` but never fill a batch.

        Args:
            token_ids (list[int]): Token IDs representing the query prompt.
            future (asyncio.Future): Future to store the result in.
        """
        self.queries.append((token_ids, future))

        if len(self.queries) >= self.batch_size:
            if self.timer:
                self.timer.cancel()
                self.timer = None
            self._batch_evaluate()
        elif self.timer is None:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, self._batch_evaluate
            )

    def _batch_evaluate(self):
        """Process all queued queries in a single batched ``generate()`` call.

        Identical prompts are evaluated once and the result is shared among
        their futures. Any exception raised while evaluating the batch is
        delivered to every future that is still pending. Futures that are
        already done (for example, cancelled by their awaiter) are skipped so
        they cannot disturb the rest of the batch.
        """
        queries, self.queries = self.queries, []
        if not queries:
            return

        if self.timer:
            self.timer.cancel()
            self.timer = None

        if self.logprobs_capture is None:
            exc = RuntimeError("Cannot use model after cleanup() has been called")
            for _, future in queries:
                self._fail_future(future, exc)
            return

        # Deduplicate: group futures by identical prompts
        query_groups = defaultdict(list)
        for token_ids, future in queries:
            query_groups[tuple(token_ids)].append(future)

        unique_token_ids = list(query_groups)

        try:
            # Everything that can fail lives inside the try block so that no
            # queued future is left unresolved.
            self.logprobs_capture.clear()

            prompts = [
                TokensPrompt(prompt_token_ids=list(token_ids))
                for token_ids in unique_token_ids
            ]

            self.llm_engine.generate(
                prompts=prompts,
                sampling_params=SamplingParams(**self.default_params),
                lora_request=self.lora_request,
                use_tqdm=False,
            )

            all_logprobs = self.logprobs_capture.get_all_logprobs()
            assert all_logprobs is not None, "Logprobs should be captured"
            assert all_logprobs.shape[0] == len(unique_token_ids), (
                f"Expected {len(unique_token_ids)} logprobs, got {all_logprobs.shape[0]}"
            )

            for i, key in enumerate(unique_token_ids):
                logprobs = all_logprobs[i]
                pending = [f for f in query_groups[key] if not f.done()]
                if len(pending) == 1:
                    pending[0].set_result(logprobs)
                else:
                    # Several awaiters share this prompt: give each its own
                    # copy so in-place edits by one do not affect the others.
                    for future in pending:
                        future.set_result(logprobs.clone())
        except Exception as exc:
            for futures in query_groups.values():
                for future in futures:
                    self._fail_future(future, exc)

    def reset_async_queries(self):
        """Clear any pending queries from the queue.

        Use this method when an exception prevented an inference algorithm
        from executing to completion. Futures belonging to the discarded
        queries are cancelled so that any task still awaiting one of them is
        released instead of waiting forever.
        """
        dropped, self.queries = self.queries, []
        if self.timer:
            self.timer.cancel()
            self.timer = None

        dropped_samples, self.sample_queries = self.sample_queries, []
        if self.sample_timer:
            self.sample_timer.cancel()
            self.sample_timer = None

        # The future is the last element of both kinds of queue entry.
        for *_, future in (*dropped, *dropped_samples):
            if not future.done():
                future.cancel()

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Does not support auto-batching. For batched sync calls, use
        ``batch_next_token_logprobs_sync`` instead.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.

        Raises:
            RuntimeError: If the result is not cached and ``cleanup()`` has been called.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        if self.logprobs_capture is None:
            raise RuntimeError("Cannot use model after cleanup() has been called")

        # Drop anything captured by a previous request.
        self.logprobs_capture.clear()

        self.llm_engine.generate(
            prompts=TokensPrompt(prompt_token_ids=list(token_ids)),
            sampling_params=SamplingParams(**self.default_params),
            lora_request=self.lora_request,
            use_tqdm=False,
        )

        result = self.logprobs_capture.get_logprobs(batch_index=0)
        assert result is not None, "Logprobs should be captured by global processor"

        if self.cache is not None:
            self.cache[key] = result

        return result

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """
        Request log probabilities of next tokens in a batch synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

        Returns:
            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.

        Raises:
            RuntimeError: If ``cleanup()`` has been called.

        Note:
            This method does not consult the output cache (unlike the async batch path,
            which delegates to the cached ``next_token_logprobs``). Every prompt is
            re-evaluated.
        """
        if self.logprobs_capture is None:
            raise RuntimeError("Cannot use model after cleanup() has been called")
        # Clear any stale captured logprobs
        self.logprobs_capture.clear()

        # Create prompts for batch
        prompts = [
            TokensPrompt(prompt_token_ids=list(token_ids))
            for token_ids in token_ids_list
        ]

        # Generate one token for each prompt
        self.llm_engine.generate(
            prompts=prompts,
            sampling_params=SamplingParams(**self.default_params),
            lora_request=self.lora_request,
            use_tqdm=False,
        )

        # Get all captured logprobs at once (optimized - single clone)
        all_logprobs = self.logprobs_capture.get_all_logprobs()
        assert all_logprobs is not None, "Logprobs should be captured"
        assert all_logprobs.shape[0] == len(token_ids_list), (
            f"Expected {len(token_ids_list)} logprobs, got {all_logprobs.shape[0]}"
        )

        return all_logprobs

    def clear_cache(self):
        """Clear output cache (no-op when caching is disabled)."""
        if self.cache is not None:
            self.cache.clear()

    def cleanup(self):
        """Explicitly clean up GPU resources. Call this when done with the model."""
        self._cleanup_engine()

    def __enter__(self):
        """Return the model itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clean up on leaving a ``with`` block; exceptions are not suppressed."""
        self.cleanup()
        return False

    async def __aenter__(self):
        """Return the model itself for use in an ``async with`` block."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up on leaving an ``async with`` block; exceptions are not suppressed."""
        self.cleanup()
        return False

    def __del__(self):
        """Clean up resources on deletion."""
        self._cleanup_engine()

    def _cleanup_engine(self):
        """Clean up the vLLM engine and associated resources.

        This is invoked from both :meth:`cleanup` (explicit, during normal
        program flow) and :meth:`__del__` (implicit, possibly at
        interpreter shutdown). The narrow exception classes below cover
        the races and idempotency issues we know about:

        * ``ImportError`` / ``AttributeError`` arise when ``__del__`` runs
          after ``sys.meta_path`` is already torn down during interpreter
          shutdown.
        * ``AssertionError`` is raised by vLLM's
          ``destroy_distributed_environment`` if it's called twice.
        * ``RuntimeError`` can surface from CUDA when the driver is
          already being torn down.

        Anything else is re-raised so real bugs are not swallowed.
        """
        try:
            # ``import gc`` can itself raise ImportError when ``__del__`` is
            # invoked after ``sys.meta_path`` has been torn down at
            # interpreter shutdown, so it lives inside the try block.
            import gc

            # Clear our references. ``hasattr`` guards against ``__del__``
            # running on an instance whose ``__init__`` did not complete.
            if hasattr(self, "logprobs_capture"):
                if self.logprobs_capture is not None:
                    self.logprobs_capture.clear()
                self.logprobs_capture = None

            # Delete the engine to free GPU memory
            if hasattr(self, "llm_engine") and self.llm_engine is not None:
                del self.llm_engine
                self.llm_engine = None

            # Force garbage collection
            gc.collect()

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Clean up distributed state
            destroy_model_parallel()
            destroy_distributed_environment()
        except (
            ImportError,
            AttributeError,
            AssertionError,
            RuntimeError,
        ) as e:
            # Best-effort log; during interpreter shutdown logging itself
            # may already be torn down, in which case silently drop.
            with contextlib.suppress(Exception):
                logging.getLogger(__name__).debug(
                    "AsyncVirtualLM cleanup raised %s: %s",
                    type(e).__name__,
                    e,
                )

    def _add_sample_query(self, prompt_token_ids, sampling_params, future):
        """Enqueue a ``sample()`` request; mirrors ``_add_query`` for the logprobs path.

        Args:
            prompt_token_ids (list[int]): Token IDs of the prompt.
            sampling_params (SamplingParams): Per-request sampling parameters.
            future (asyncio.Future): Future to store the sampled token IDs in.
        """
        self.sample_queries.append((prompt_token_ids, sampling_params, future))
        if len(self.sample_queries) >= self.batch_size:
            if self.sample_timer:
                self.sample_timer.cancel()
                self.sample_timer = None
            self._batch_sample_evaluate()
        elif self.sample_timer is None:
            self.sample_timer = asyncio.get_running_loop().call_later(
                self.timeout, self._batch_sample_evaluate
            )

    def _batch_sample_evaluate(self):
        """Dispatch queued ``sample()`` requests in one batched ``generate()`` call.

        Futures that are already done (for example, cancelled by their
        awaiter) are skipped; any exception raised while generating is
        delivered to every future that is still pending.
        """
        queries, self.sample_queries = self.sample_queries, []
        if not queries:
            return
        if self.sample_timer:
            self.sample_timer.cancel()
            self.sample_timer = None
        if self.logprobs_capture is None:
            exc = RuntimeError("Cannot use model after cleanup() has been called")
            for _, _, future in queries:
                self._fail_future(future, exc)
            return
        try:
            outputs = self.llm_engine.generate(
                prompts=[TokensPrompt(prompt_token_ids=t) for t, _, _ in queries],
                sampling_params=[sp for _, sp, _ in queries],
                lora_request=self.lora_request,
                use_tqdm=False,
            )
            assert len(outputs) == len(queries)
            for output, (_, _, future) in zip(outputs, queries):
                self._resolve_future(future, list(output.outputs[0].token_ids))
        except Exception as exc:
            for _, _, future in queries:
                self._fail_future(future, exc)

    async def sample(
        self,
        prompt_token_ids,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample from the language model.

        Concurrent calls are auto-batched into a single ``LLM.generate()``
        so vLLM continuous-batches the decode steps. Use with ``await``.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
            temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[int]): The sampled token IDs, without a trailing
                end-of-sequence token.
        """
        future = asyncio.get_running_loop().create_future()
        self._add_sample_query(
            list(prompt_token_ids),
            SamplingParams(
                n=1,
                max_tokens=max_tokens,
                temperature=temperature,
                seed=seed,
                # EOS tokens are passed to vLLM as stop strings, decoded
                # from the byte vocabulary.
                stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
            ),
            future,
        )
        token_ids = await future
        # Strip a trailing EOS token if the generation ended on one.
        if token_ids and token_ids[-1] in eos_token_ids:
            token_ids = token_ids[:-1]
        return token_ids

__init__(llm_engine, logprobs_capture, cache_size=0, cache_opts=None, batch_size=20, timeout=0.02)

Initialize an AsyncVirtualLM instance.

Parameters:

Name Type Description Default
llm_engine LLM

The vLLM engine instance.

required
logprobs_capture GlobalLogprobsCapture

The global logprobs capture processor.

required
cache_size int

Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.

0
cache_opts dict

Additional options to pass to the OutputCache constructor. Defaults to None (no extra options).

None
batch_size int

Maximum queries to process in one batch during auto-batching. Defaults to 20.

20
timeout float

Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when batch_size is reached. Defaults to 0.02.

0.02
Note

The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine. batch_next_token_logprobs_sync bypasses this cache and always re-evaluates; the other three logprobs methods consult it.

Source code in genlm/backend/llm/vllm.py
def __init__(
    self,
    llm_engine,
    logprobs_capture,
    cache_size=0,
    cache_opts=None,
    batch_size=20,
    timeout=0.02,
):
    """Set up the engine handle, the optional output cache and the batching queues.

    Args:
        llm_engine (LLM): The vLLM engine instance.
        logprobs_capture (GlobalLogprobsCapture): The global logprobs capture processor.
        cache_size (int, optional): Maximum size of the output cache; 0 disables caching. Defaults to 0.
        cache_opts (dict, optional): Extra keyword options for the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to None (no extra options).
        batch_size (int, optional): Maximum number of queries dispatched in one auto-batch. Defaults to 20.
        timeout (float, optional): Seconds to wait after the first queued query before the current batch is processed. A batch is also processed as soon as ``batch_size`` is reached. Defaults to 0.02.

    Note:
        The output cache holds log probabilities for token sequences already seen, avoiding repeated requests; KV caching is the vLLM engine's own concern.
        ``batch_next_token_logprobs_sync`` bypasses this cache and always re-evaluates; the other three logprobs methods consult it.
    """
    self.llm_engine = llm_engine
    self.logprobs_capture = logprobs_capture
    self.tokenizer = llm_engine.get_tokenizer()

    # The output cache only exists when a positive size is requested.
    if cache_size > 0:
        extra_cache_opts = cache_opts or {}
        self.cache = OutputCache(maxsize=cache_size, **extra_cache_opts)
    else:
        self.cache = None

    # LoRA bookkeeping: nothing is registered or active initially.
    self.lora_request = None
    self.lora_name_to_ids = {}

    # Queue and timer for batched logprobs requests.
    self.queries = []
    self.timer = None
    self.batch_size = batch_size
    self.timeout = timeout

    # Queue and timer for batched sample requests.
    self.sample_queries = []
    self.sample_timer = None

    super().__init__(tokenizer=self.tokenizer)

from_name(model_name, engine_opts=None, **kwargs) classmethod

Create an AsyncVirtualLM instance from a model name.

Parameters:

Name Type Description Default
model_name str

Name of the model to load.

required
engine_opts dict

Additional options to pass to the LLM engine.

None
**kwargs

Additional arguments passed to AsyncVirtualLM constructor.

{}

Returns:

Type Description
AsyncVirtualLM

An AsyncVirtualLM instance.

Source code in genlm/backend/llm/vllm.py
@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
    """Build an `AsyncVirtualLM` by loading the named model into a vLLM engine.

    Args:
        model_name (str): Name of the model to load.
        engine_opts (dict): Additional options to pass to the `LLM` engine;
            these take precedence over the defaults chosen here.
        **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

    Returns:
        (AsyncVirtualLM): An `AsyncVirtualLM` instance.
    """
    if not HAS_VLLM:
        raise ImportError(  # pragma: no cover
            "vLLM not available. Install vLLM or use AsyncTransformer instead."
        )

    # Start from the defaults, then let caller-supplied options win.
    merged_opts = {
        "enable_prefix_caching": True,
        "disable_log_stats": True,
        "gpu_memory_utilization": 0.9,
    }
    merged_opts.update(engine_opts or {})

    engine = LLM(model=model_name, tokenizer=model_name, **merged_opts)

    # Hook the capture processor into the model runner's logits processors.
    capture = GlobalLogprobsCapture()
    runner = cls._get_model_runner(engine)
    runner.input_batch.logitsprocs.argmax_invariant.append(capture)

    return cls(engine, capture, **kwargs)

underlying_model property

Access the underlying model for advanced use cases.

clear_lora()

Disable any active LoRA adapter for the vLLM engine.

Source code in genlm/backend/llm/vllm.py
def clear_lora(self):
    """Deactivate the current LoRA adapter, if any.

    Resets the stored LoRA request so later engine calls are issued
    without an adapter.
    """
    self.lora_request = None

add_new_lora(lora_path, lora_name='lora_1')

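Register a LoRA adapter under a name by creating a deterministic id for it. This call only records the name-to-id mapping; lora_path is not used here and is supplied again to set_lora() when the adapter is activated.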

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Name to assign to the loaded adapter.

'lora_1'
Notes

This does not activate the adapter immediately. Call set_lora() to enable the adapter.

Source code in genlm/backend/llm/vllm.py
def add_new_lora(self, lora_path, lora_name="lora_1"):
    """Register a LoRA adapter name by assigning it a deterministic id.

    Only the name-to-id mapping is recorded here; no adapter weights are
    read by this method.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            Not used by this method; the path is supplied to `set_lora()`.
        lora_name (str): Name to assign to the adapter.

    Notes:
        This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
    """
    adapter_id = self.hash_to_int(lora_name)
    self.lora_name_to_ids[lora_name] = adapter_id

hash_to_int(value)

Generates a deterministic id for a LoRA adapter from its name. The id comes from a 4-byte hash of the name, so distinct names can in principle collide.

Parameters:

Name Type Description Default
value str

The name of the LoRA adapter to hash.

required

Returns:

Type Description

An integer ID corresponding to the LoRA adapter, in the range [1, 2^31 - 2].

Source code in genlm/backend/llm/vllm.py
def hash_to_int(self, value):
    """Generates a deterministic id for a LoRA adapter from its name.

    The id is derived from a 4-byte SHAKE-128 digest of the UTF-8 encoded
    name, so the same name always maps to the same id. Distinct names can
    in principle collide.

    Args:
        value (str): The name of the LoRA adapter to hash.

    Returns:
        An integer ID corresponding to the LoRA adapter, in the range [1, 2^31 - 2].
    """
    digest = hashlib.shake_128(value.encode("utf-8")).digest(4)
    # Reduce modulo 2^31 - 2 and shift by one so that 0 is never returned.
    return (int.from_bytes(digest, "big") % (2**31 - 2)) + 1

set_lora(lora_path, lora_name='lora_1')

Configure a LoRA adapter request for the vLLM engine.

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Identifier name to associate with this LoRA adapter within vLLM.

'lora_1'
lora_id int

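Globally unique ID for the adapter. Note: set_lora() does not accept this argument; the ID is taken from the value registered for lora_name by add_new_lora().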

required
Source code in genlm/backend/llm/vllm.py
def set_lora(self, lora_path, lora_name="lora_1"):
    """Configure a LoRA adapter request for the vLLM engine.

    The adapter id is not passed in; it is looked up from the name
    registered earlier through `add_new_lora()`.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.

    Raises:
        ValueError: If ``lora_name`` has not been registered with `add_new_lora()`.
    """
    if lora_name not in self.lora_name_to_ids:
        raise ValueError(
            f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
        )
    adapter_id = self.lora_name_to_ids[lora_name]
    self.lora_request = LoRARequest(lora_name, adapter_id, lora_path)

next_token_logprobs(token_ids) async

Request log probabilities of next token asynchronously with auto-batching.

Concurrent calls to this method are automatically batched into a single LLM.generate() call for efficiency. Use with await.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Name Type Description
result Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/vllm.py
async def next_token_logprobs(self, token_ids):
    """Asynchronously fetch next-token log probabilities, with auto-batching.

    A cached result is returned directly when available. Otherwise the
    request is queued, and calls made concurrently are combined into a
    single ``LLM.generate()`` call. Use with ``await``.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        result (torch.Tensor): Normalized log probability tensor.
    """
    cache_key = tuple(token_ids)

    # Fast path: answer from the output cache when it is enabled and warm.
    if self.cache is not None:
        if cache_key in self.cache:
            return self.cache[cache_key]

    # Slow path: queue the request and wait for its batch to be evaluated.
    pending = asyncio.get_running_loop().create_future()
    self._add_query(token_ids, pending)
    logprobs = await pending

    if self.cache is not None:
        self.cache[cache_key] = logprobs
    return logprobs

reset_async_queries()

Clear any pending queries from the queue.

Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/vllm.py
def reset_async_queries(self):
    """Clear any pending queries from the queue.

    Use this method when an exception prevented an inference algorithm
    from executing to completion. Both the logprobs queue and the sample
    queue are emptied and their timers cancelled. Futures belonging to the
    discarded queries are cancelled so that any task still awaiting one of
    them is released instead of waiting forever.
    """
    dropped, self.queries = self.queries, []
    if self.timer:
        self.timer.cancel()
        self.timer = None

    dropped_samples, self.sample_queries = self.sample_queries, []
    if self.sample_timer:
        self.sample_timer.cancel()
        self.sample_timer = None

    # The future is the last element of both kinds of queue entry.
    for *_, future in (*dropped, *dropped_samples):
        if not future.done():
            future.cancel()

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Does not support auto-batching. For batched sync calls, use batch_next_token_logprobs_sync instead.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/vllm.py
def next_token_logprobs_sync(self, token_ids):
    """Synchronously fetch next-token log probabilities for a single prompt.

    Does not support auto-batching. For batched sync calls, use
    ``batch_next_token_logprobs_sync`` instead.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.

    Raises:
        RuntimeError: If ``self.logprobs_capture`` is ``None`` (the model was
            used after ``cleanup()``) and the prompt is not in the cache.
    """
    cache_key = tuple(token_ids)
    if self.cache is not None and cache_key in self.cache:
        return self.cache[cache_key]

    if self.logprobs_capture is None:
        raise RuntimeError("Cannot use model after cleanup() has been called")

    # Discard anything captured by an earlier call before generating.
    self.logprobs_capture.clear()

    # The return value of generate() is ignored; the log probabilities are
    # read back from the capture object afterwards.
    self.llm_engine.generate(
        prompts=TokensPrompt(prompt_token_ids=list(token_ids)),
        sampling_params=SamplingParams(**self.default_params),
        lora_request=self.lora_request,
        use_tqdm=False,
    )

    logprobs = self.logprobs_capture.get_logprobs(batch_index=0)
    assert logprobs is not None, "Logprobs should be captured by global processor"

    if self.cache is None:
        return logprobs

    self.cache[cache_key] = logprobs
    return logprobs

batch_next_token_logprobs_sync(token_ids_list)

Request log probabilities of next tokens in a batch synchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists, each representing a prompt to the language model.

required

Returns:

Type Description
Tensor

A tensor of normalized log probability tensors, one for each prompt in the input list.

Note

This method does not consult the output cache (unlike the async batch path, which delegates to the cached next_token_logprobs). Every prompt is re-evaluated.

Source code in genlm/backend/llm/vllm.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """
    Synchronously compute next-token log probabilities for a batch of prompts.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

    Returns:
        (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.

    Raises:
        RuntimeError: If ``self.logprobs_capture`` is ``None`` (the model was
            used after ``cleanup()``).

    Note:
        This method does not consult the output cache (unlike the async batch path,
        which delegates to the cached ``next_token_logprobs``). Every prompt is
        re-evaluated.
    """
    capture = self.logprobs_capture
    if capture is None:
        raise RuntimeError("Cannot use model after cleanup() has been called")

    # Start from an empty capture so only this batch is read back below.
    capture.clear()

    # One TokensPrompt per input sequence, in input order.
    batch = [TokensPrompt(prompt_token_ids=list(ids)) for ids in token_ids_list]

    # The return value of generate() is ignored; results come from the capture.
    self.llm_engine.generate(
        prompts=batch,
        sampling_params=SamplingParams(**self.default_params),
        lora_request=self.lora_request,
        use_tqdm=False,
    )

    # Fetch everything that was captured in one call.
    all_logprobs = capture.get_all_logprobs()
    assert all_logprobs is not None, "Logprobs should be captured"
    assert all_logprobs.shape[0] == len(token_ids_list), (
        f"Expected {len(token_ids_list)} logprobs, got {all_logprobs.shape[0]}"
    )
    return all_logprobs

clear_cache()

Clear output cache.

Source code in genlm/backend/llm/vllm.py
def clear_cache(self):
    """Clear output cache.

    Does nothing when output caching is disabled (``self.cache`` is ``None``).
    """
    # Compare against None, as the other methods that guard on self.cache do,
    # instead of relying on truthiness: a cache object that evaluates as
    # falsy must still be cleared.
    if self.cache is not None:
        self.cache.clear()

cleanup()

Explicitly clean up GPU resources. Call this when done with the model.

Source code in genlm/backend/llm/vllm.py
def cleanup(self):
    """Explicitly release the model's GPU resources; call this once the model is no longer needed.

    Delegates to ``self._cleanup_engine()``.
    """
    release_engine = self._cleanup_engine
    release_engine()

__del__()

Clean up resources on deletion.

Source code in genlm/backend/llm/vllm.py
def __del__(self):
    """Release resources when the object is deleted.

    Delegates to ``self._cleanup_engine()``, the same call ``cleanup()`` makes.
    """
    release_engine = self._cleanup_engine
    release_engine()

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Concurrent calls are auto-batched into a single LLM.generate() so vLLM continuous-batches the decode steps. Use with await.

Parameters:

Name Type Description Default
prompt_token_ids list[int]

The token IDs of the prompt.

required
eos_token_ids list[int]

The token IDs of the end-of-sequence tokens.

required
temperature float

The temperature to use to rescale the logits. Defaults to 1.0.

1.0
max_tokens int

The maximum number of tokens to generate.

required
seed int

The seed for the random number generator. Defaults to None.

None

Returns:

Type Description
list[int]

The sampled token IDs.

Source code in genlm/backend/llm/vllm.py
async def sample(
    self,
    prompt_token_ids,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Draw one sampled continuation of a prompt from the language model.

    Concurrent calls are auto-batched into a single ``LLM.generate()``
    so vLLM continuous-batches the decode steps. Use with ``await``.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs. A trailing end-of-sequence token,
            if present, is removed.
    """
    pending = asyncio.get_running_loop().create_future()
    prompt = list(prompt_token_ids)

    # Each EOS token is passed to the sampler as a stop string, obtained by
    # decoding that token's bytes from the byte vocabulary.
    params = SamplingParams(
        n=1,
        max_tokens=max_tokens,
        temperature=temperature,
        seed=seed,
        stop=[self.byte_vocab[eos_id].decode() for eos_id in eos_token_ids],
    )
    self._add_sample_query(prompt, params, pending)

    sampled = await pending

    # Strip the final token when it is one of the EOS ids.
    ends_with_eos = bool(sampled) and sampled[-1] in eos_token_ids
    return sampled[:-1] if ends_with_eos else sampled

AsyncTransformer

Bases: AsyncLM

Asynchronous wrapper around a HuggingFace causal language model with caching support.

This class provides an asynchronous interface to HuggingFace language models with automatic batching and caching (output and KV) for improved efficiency.

Source code in genlm/backend/llm/hf.py
class AsyncTransformer(AsyncLM):
    """Asynchronous wrapper around a HuggingFace causal language model with caching support.

    This class provides an asynchronous interface to HuggingFace language models with automatic batching
    and caching (output and KV) for improved efficiency.
    """

    @classmethod
    def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
        """Create an AsyncTransformer instance from a pretrained HuggingFace model.

        Args:
            model_id (str): Model identifier in HuggingFace's model hub.
            bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
                Defaults to None.
            hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
                Defaults to None.
            **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

        Returns:
            (AsyncTransformer): An initialized `AsyncTransformer` instance.
        """
        if bitsandbytes_opts:
            bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
        else:
            bnb_config = None

        # Defaults for model loading; entries in ``hf_opts`` take precedence.
        _hf_opts = {
            "device_map": "auto",
            "torch_dtype": "auto",
        }
        if hf_opts:
            _hf_opts.update(hf_opts)

        tok = AutoTokenizer.from_pretrained(model_id)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, **_hf_opts
        )

        return cls(mod, tok, **kwargs)

    @torch.no_grad()
    def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
        """Initialize an AsyncTransformer instance.

        Args:
            hf_model: A HuggingFace CausalLM model instance.
            hf_tokenizer: A HuggingFace Tokenizer.
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
                Defaults to 20.
            timeout (float, optional): Seconds to wait since last query before processing current batch.
                Defaults to 0.02.
        """
        self.model = hf_model
        self.tokenizer = hf_tokenizer
        self.device = hf_model.device
        self.cache = TokenTrie()

        # Queries to be batched. Each query is a sequence of tokens,
        # and a Future to be called when the query is resolved.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        # Handle of the scheduled batch flush, or None when none is pending.
        self.timer = None

        self.model.eval()

        super().__init__(tokenizer=self.tokenizer)

    def clear_cache(self):
        """Clear the cache of log probabilities and key/value pairs."""
        self.cache = TokenTrie()

    def clear_kv_cache(self):
        """Clear any key and value vectors from the cache."""
        self.cache.clear_kv_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue.

        Use this method when an exception prevented an inference algorithm from
        executing to completion. The scheduled batch flush, if any, is cancelled
        as well so that no stale timer stays armed. Futures of the dropped
        queries are not resolved by this method.
        """
        self.queries = []
        if self.timer:
            self.timer.cancel()
            self.timer = None

    @torch.no_grad()
    def cache_kv(self, prompt_tokens):
        """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

        Args:
            prompt_tokens (list[int]): token ids for the prompt to cache.
        """
        result = self.model(torch.tensor([prompt_tokens]).to(self.device))
        node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
        node.past_key_values = result.past_key_values

    def add_new_lora(self, lora_path, lora_name="lora_1"):
        """Load a LoRA adapter into the base model.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Name to assign to the loaded adapter.

        Notes:
            This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
        """
        self.model.load_adapter(lora_path, lora_name)

    def set_lora(self, lora_path=None, lora_name="lora_1"):
        """Activate a previously loaded LoRA adapter.

        Both the KV cache and the output cache are cleared before the adapter
        is switched.

        Args:
            lora_path: Not used by this method.
            lora_name (str): Name of the LoRA adapter to activate.

        Raises:
            ValueError: If no adapter named ``lora_name`` has been loaded.
        """
        # A model that never had an adapter loaded may not expose
        # ``peft_config`` at all; treat that as "nothing loaded" so the caller
        # gets the ValueError below rather than an AttributeError.
        loaded_adapters = getattr(self.model, "peft_config", None) or {}
        if lora_name not in loaded_adapters:
            raise ValueError(
                f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
            )

        self.clear_kv_cache()
        self.clear_cache()
        self.model.set_adapter(lora_name)

    def clear_lora(self):
        """
        Deactivate all LoRA adapters.

        Both the KV cache and the output cache are cleared first.
        """
        self.clear_kv_cache()
        self.clear_cache()
        self.model.set_adapter([])

    @torch.no_grad()
    def batch_evaluate_queries(self):
        """
        Process a batch of queued language model queries.

        This method is called internally when the `batch_size` has been met or the `timeout` has expired.

        Queued queries with the same prompt and the same cached prefix are
        evaluated once and share the result. Each query's future is resolved
        with the logits row computed for it.
        """

        queries, self.queries = self.queries, []
        if not queries:
            return

        query_groups = defaultdict(list)
        for query in queries:
            # ``prompt`` only holds the tokens after the cached prefix, so two
            # queries are interchangeable only if they also continue from the
            # same cached key/value object. ``id`` is stable here because each
            # queued query keeps its ``past`` alive.
            key = (tuple(query.prompt), id(query.past))
            query_groups[key].append(query)

        # Use one representative query from each group
        groups = list(query_groups.values())
        unique_queries = [group[0] for group in groups]

        past_example = next((q.past for q in unique_queries if q.past), False)
        max_past_length = max(q.past_len for q in unique_queries)
        max_query_length = max(len(q.prompt) for q in unique_queries)

        # Fall back to token id 0 when the tokenizer defines no pad token.
        padding_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else 0
        )

        input_ids = torch.tensor(
            [
                q.prompt_padded(padding_token_id, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        attn_masks = torch.tensor(
            [
                q.attention_mask(max_past_length, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        posn_ids = torch.tensor(
            [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
        ).to(self.device)
        if past_example:
            # For every layer, and for entry 0 and entry 1 of that layer's
            # cache, concatenate the padded per-query tensors along dim 0.
            pasts = [
                [
                    torch.cat(
                        (
                            *(
                                q.past_padded(
                                    layer,
                                    j,
                                    max_past_length,
                                    past_example[0][0].dtype,
                                    self.device,
                                    past_example[0][0].shape,
                                )
                                for q in unique_queries
                            ),
                        ),
                        dim=0,
                    )
                    for j in range(2)
                ]
                for layer in range(len(past_example))
            ]
        else:
            pasts = None

        pasts = DynamicCache.from_legacy_cache(pasts)

        # NOTE(review): if ``from_legacy_cache`` never returns None, then
        # ``use_cache`` below is always True -- confirm that is intended.
        results = self.model(
            input_ids,
            attention_mask=attn_masks,
            position_ids=posn_ids,
            past_key_values=pasts,
            use_cache=pasts is not None,
        )

        assert len(results.logits) == len(unique_queries)

        # Row i of the logits belongs to group i; hand it to every duplicate.
        for result, group in zip(results.logits, groups):
            for dup_query in group:
                dup_query.future.set_result(result)

    @torch.no_grad()
    def add_query(self, query, future, past):
        """Add a query to be evaluated in the next batch.

        This method is called internally when a `next_token_logprobs` request is made.

        Args:
            query (list[int]): Token IDs representing the query prompt
            future (asyncio.Future): Future to store the result in
            past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
                or None if this is a new query
        """
        self.queries.append(Query(query, future, past))

        # Each new query restarts the timeout: cancel the pending flush, then
        # either flush right away (batch is full) or schedule a new one.
        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple:
                - CacheNode: The deepest node in the cache tree that matches the token sequence
                - int: Number of tokens matched from the start of token_ids
                - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node,
                    or None if no cached states were found
                - int: Base index indicating where the past states start in token_ids
        """
        # Walk while tokens can be found
        node = self.cache
        next_token_index = 0

        past = None
        base = 0
        while next_token_index < len(token_ids):
            # Remember the deepest node seen so far that carries KV states,
            # and how many tokens those states cover.
            if node.past_key_values is not None:
                past = node.past_key_values
                base = next_token_index
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        return node, next_token_index, past, base

    @torch.no_grad()
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If ``token_ids`` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, base = self.walk_cache(token_ids)

        # If we processed all tokens, then we're done.
        if next_token_index == len(token_ids):
            return node.logprobs

        # Create a future with the prompt; only the tokens after the cached
        # prefix are queued, together with that prefix's KV states.
        future = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], future, past)
        logits = await future

        # Create new nodes
        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    @torch.no_grad()
    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If ``token_ids`` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        # Walk while tokens can be found
        node, next_token_index, past, base = self.walk_cache(token_ids)

        if next_token_index == len(token_ids):
            return node.logprobs

        # The input is sliced from ``base``, so it must be paired with the KV
        # states that cover ``token_ids[:base]``. Those are ``past``, which may
        # belong to an ancestor of ``node``; ``node.past_key_values`` can be
        # None even though a cached prefix exists.
        # NOTE(review): if the cached states are a mutable cache object that
        # the model extends in place, a copy may be needed here -- confirm.
        logits = self.model(
            torch.tensor([token_ids[base:]]).to(self.device),
            past_key_values=past,
            use_cache=past is not None,
        ).logits[0]

        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    def next_token_logprobs_uncached(self, token_ids):
        """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If ``token_ids`` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        with torch.no_grad():
            logits = self.model(
                torch.tensor([token_ids]).to(self.device),
                past_key_values=None,
                use_cache=False,
            ).logits[0]
            # Normalize the logits at the last position only.
            return torch.log_softmax(logits[-1], dim=0)

from_name(model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs) classmethod

Create an AsyncTransformer instance from a pretrained HuggingFace model.

Parameters:

Name Type Description Default
model_id str

Model identifier in HuggingFace's model hub.

required
bitsandbytes_opts dict

Additional configuration options for bitsandbytes quantization. Defaults to None.

None
hf_opts dict

Additional configuration options for loading the HuggingFace model. Defaults to None.

None
**kwargs

Additional arguments passed to the AsyncTransformer constructor

{}

Returns:

Type Description
AsyncTransformer

An initialized AsyncTransformer instance.

Source code in genlm/backend/llm/hf.py
@classmethod
def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
    """Build an AsyncTransformer from a pretrained HuggingFace model.

    Args:
        model_id (str): Model identifier in HuggingFace's model hub.
        bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
            Defaults to None.
        hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
            Defaults to None.
        **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

    Returns:
        (AsyncTransformer): An initialized `AsyncTransformer` instance.
    """
    # Quantization is configured only when options were supplied.
    quantization = (
        BitsAndBytesConfig(**bitsandbytes_opts) if bitsandbytes_opts else None
    )

    # Start from the automatic device/dtype defaults; caller-supplied
    # options override them.
    load_opts = dict(device_map="auto", torch_dtype="auto")
    if hf_opts:
        load_opts.update(hf_opts)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=quantization, **load_opts
    )
    return cls(model, tokenizer, **kwargs)

__init__(hf_model, hf_tokenizer, batch_size=20, timeout=0.02)

Initialize an AsyncTransformer instance.

Parameters:

Name Type Description Default
hf_model

A HuggingFace CausalLM model instance.

required
hf_tokenizer

A HuggingFace Tokenizer.

required
batch_size int

Maximum queries to process in one batch during auto-batching. Defaults to 20.

20
timeout float

Seconds to wait since last query before processing current batch. Defaults to 0.02.

0.02
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
    """Set up an AsyncTransformer around a model and its tokenizer.

    Args:
        hf_model: A HuggingFace CausalLM model instance.
        hf_tokenizer: A HuggingFace Tokenizer.
        batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
            Defaults to 20.
        timeout (float, optional): Seconds to wait since last query before processing current batch.
            Defaults to 0.02.
    """
    # Wrapped model and tokenizer; the device is taken from the model.
    self.model, self.tokenizer = hf_model, hf_tokenizer
    self.device = hf_model.device

    # Start with an empty trie as the cache.
    self.cache = TokenTrie()

    # Auto-batching state: the queue of pending queries (each a token
    # sequence plus the Future to resolve), its limits, and the handle of
    # the scheduled flush (None while nothing is scheduled).
    self.queries = []
    self.timer = None
    self.batch_size, self.timeout = batch_size, timeout

    # Put the model in evaluation mode.
    self.model.eval()

    super().__init__(tokenizer=self.tokenizer)

clear_cache()

Clear the cache of log probabilities and key/value pairs.

Source code in genlm/backend/llm/hf.py
def clear_cache(self):
    """Drop all cached log probabilities and key/value pairs.

    The current cache is replaced by a newly created, empty ``TokenTrie``.
    """
    empty_trie = TokenTrie()
    self.cache = empty_trie

clear_kv_cache()

Clear any key and value vectors from the cache.

Source code in genlm/backend/llm/hf.py
def clear_kv_cache(self):
    """Remove any key and value vectors from the cache.

    Delegates to the cache object's own ``clear_kv_cache()``.
    """
    trie = self.cache
    trie.clear_kv_cache()

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/hf.py
def reset_async_queries(self):
    """Clear any pending language model queries from the queue.

    Use this method when an exception prevented an inference algorithm from
    executing to completion. The scheduled batch flush, if any, is cancelled
    as well so that no stale timer stays armed. Futures of the dropped
    queries are not resolved by this method.
    """
    self.queries = []
    if self.timer:
        self.timer.cancel()
        self.timer = None

cache_kv(prompt_tokens)

Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

Parameters:

Name Type Description Default
prompt_tokens list[int]

token ids for the prompt to cache.

required
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def cache_kv(self, prompt_tokens):
    """Run the model on a prompt and store its key and value vectors in the cache.

    Future queries that have this prompt as a prefix will only run the LLM on new tokens.

    Args:
        prompt_tokens (list[int]): token ids for the prompt to cache.
    """
    # The prompt is fed to the model as a batch of one sequence.
    batch = torch.tensor([prompt_tokens]).to(self.device)
    output = self.model(batch)

    # Insert the prompt into the cache trie, then attach the KV states to
    # the node returned for it.
    prompt_node = self.cache.extend_cache(0, prompt_tokens, output.logits[0], 0)
    prompt_node.past_key_values = output.past_key_values

add_new_lora(lora_path, lora_name='lora_1')

Load a LoRA adapter into the base model.

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Name to assign to the loaded adapter.

'lora_1'
Notes

This does not activate the adapter immediately. Call set_lora() to enable the adapter.

Source code in genlm/backend/llm/hf.py
def add_new_lora(self, lora_path, lora_name="lora_1"):
    """Register a LoRA adapter with the base model under a given name.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Name to assign to the loaded adapter.

    Notes:
        This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
    """
    load_adapter = self.model.load_adapter
    load_adapter(lora_path, lora_name)

set_lora(lora_path=None, lora_name='lora_1')

Activate a previously loaded LoRA adapter.

Parameters:

Name Type Description Default
lora_name str

Name of the LoRA adapter to activate.

'lora_1'
Source code in genlm/backend/llm/hf.py
def set_lora(self, lora_path=None, lora_name="lora_1"):
    """Activate a previously loaded LoRA adapter.

    Both the KV cache and the output cache are cleared before the adapter
    is switched.

    Args:
        lora_path: Not used by this method.
        lora_name (str): Name of the LoRA adapter to activate.

    Raises:
        ValueError: If no adapter named ``lora_name`` has been loaded.
    """
    # A model that never had an adapter loaded may not expose ``peft_config``
    # at all; treat that as "nothing loaded" so the caller gets the
    # ValueError below rather than an AttributeError.
    loaded_adapters = getattr(self.model, "peft_config", None) or {}
    if lora_name not in loaded_adapters:
        raise ValueError(
            f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
        )

    self.clear_kv_cache()
    self.clear_cache()
    self.model.set_adapter(lora_name)

clear_lora()

Deactivate all LoRA adapters.

Source code in genlm/backend/llm/hf.py
def clear_lora(self):
    """
    Switch off every LoRA adapter.

    The KV cache and then the output cache are cleared before the model's
    active adapter set is replaced by an empty list.
    """
    for reset in (self.clear_kv_cache, self.clear_cache):
        reset()
    self.model.set_adapter([])

batch_evaluate_queries()

Process a batch of queued language model queries.

This method is called internally when the batch_size has been met or the timeout has expired.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def batch_evaluate_queries(self):
    """
    Process a batch of queued language model queries.

    This method is called internally when the `batch_size` has been met or the `timeout` has expired.

    Queued queries with the same prompt and the same cached prefix are
    evaluated once and share the result. Each query's future is resolved
    with the logits row computed for it.
    """

    queries, self.queries = self.queries, []
    if not queries:
        return

    query_groups = defaultdict(list)
    for query in queries:
        # ``prompt`` only holds the tokens after the cached prefix, so two
        # queries are interchangeable only if they also continue from the
        # same cached key/value object. ``id`` is stable here because each
        # queued query keeps its ``past`` alive.
        key = (tuple(query.prompt), id(query.past))
        query_groups[key].append(query)

    # Use one representative query from each group
    groups = list(query_groups.values())
    unique_queries = [group[0] for group in groups]

    past_example = next((q.past for q in unique_queries if q.past), False)
    max_past_length = max(q.past_len for q in unique_queries)
    max_query_length = max(len(q.prompt) for q in unique_queries)

    # Fall back to token id 0 when the tokenizer defines no pad token.
    padding_token_id = (
        self.tokenizer.pad_token_id
        if self.tokenizer.pad_token_id is not None
        else 0
    )

    input_ids = torch.tensor(
        [
            q.prompt_padded(padding_token_id, max_query_length)
            for q in unique_queries
        ]
    ).to(self.device)
    attn_masks = torch.tensor(
        [
            q.attention_mask(max_past_length, max_query_length)
            for q in unique_queries
        ]
    ).to(self.device)
    posn_ids = torch.tensor(
        [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
    ).to(self.device)
    if past_example:
        # For every layer, and for entry 0 and entry 1 of that layer's cache,
        # concatenate the padded per-query tensors along dim 0.
        pasts = [
            [
                torch.cat(
                    (
                        *(
                            q.past_padded(
                                layer,
                                j,
                                max_past_length,
                                past_example[0][0].dtype,
                                self.device,
                                past_example[0][0].shape,
                            )
                            for q in unique_queries
                        ),
                    ),
                    dim=0,
                )
                for j in range(2)
            ]
            for layer in range(len(past_example))
        ]
    else:
        pasts = None

    pasts = DynamicCache.from_legacy_cache(pasts)

    # NOTE(review): if ``from_legacy_cache`` never returns None, then
    # ``use_cache`` below is always True -- confirm that is intended.
    results = self.model(
        input_ids,
        attention_mask=attn_masks,
        position_ids=posn_ids,
        past_key_values=pasts,
        use_cache=pasts is not None,
    )

    assert len(results.logits) == len(unique_queries)

    # Row i of the logits belongs to group i; hand it to every duplicate.
    for result, group in zip(results.logits, groups):
        for dup_query in group:
            dup_query.future.set_result(result)

add_query(query, future, past)

Add a query to be evaluated in the next batch.

This method is called internally when a next_token_logprobs request is made.

Parameters:

Name Type Description Default
query list[int]

Token IDs representing the query prompt

required
future Future

Future to store the result in

required
past list[tuple[Tensor]] | None

Past key/value states from previous evaluation, or None if this is a new query

required
Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def add_query(self, query, future, past):
    """Queue a query for evaluation in the next batch.

    This method is called internally when a `next_token_logprobs` request is made.
    The batch is evaluated immediately once it holds `batch_size` queries;
    otherwise a flush is scheduled `timeout` seconds from now, replacing any
    previously scheduled flush.

    Args:
        query (list[int]): Token IDs representing the query prompt
        future (asyncio.Future): Future to store the result in
        past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
            or None if this is a new query
    """
    self.queries.append(Query(query, future, past))

    # A new arrival invalidates the previously scheduled flush.
    scheduled = self.timer
    if scheduled:
        scheduled.cancel()
        self.timer = None

    if len(self.queries) < self.batch_size:
        # Batch not full yet: wait for more queries until the timeout fires.
        loop = asyncio.get_running_loop()
        self.timer = loop.call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )
        return

    self.batch_evaluate_queries()

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

Name Type Description Default
token_ids list[int]

Sequence of token IDs to follow in the cache tree

required

Returns:

Name Type Description
tuple
  • CacheNode: The deepest node in the cache tree that matches the token sequence
  • int: Number of tokens matched from the start of token_ids
  • list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node, or None if no cached states were found
  • int: Base index indicating where the past states start in token_ids
Source code in genlm/backend/llm/hf.py
def walk_cache(self, token_ids):
    """Follow `token_ids` down the cache tree as far as the tree allows.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple:
            - CacheNode: The deepest node in the cache tree that matches the token sequence
            - int: Number of tokens matched from the start of token_ids
            - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest visited
                node that stores them, or None if no cached states were found
            - int: Base index indicating where the past states start in token_ids
    """
    current = self.cache
    matched = 0
    past, base = None, 0

    for position, token in enumerate(token_ids):
        # Remember the most recent key/value states seen on the way down,
        # together with the position in token_ids they correspond to.
        if current.past_key_values is not None:
            past, base = current.past_key_values, position

        if not current.has_token(token):
            break

        current = current.get_token(token)
        matched = position + 1

    return current, matched, past, base

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
async def next_token_logprobs(self, token_ids):
    """Return next-token log probabilities for a prompt; use with `await`.

    The method is asynchronous because concurrent requests are batched
    automatically. If the whole prompt is already present in the cache tree the
    cached result is returned without queuing a model query.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    node, matched, past, base = self.walk_cache(token_ids)

    if matched != len(token_ids):
        # Cache miss: queue the uncached suffix (plus any reusable key/value
        # states) and wait for the batched evaluation to resolve the future.
        pending = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], pending, past)
        logits = await pending

        # Record the fresh results in the cache tree below the matched node.
        node = node.extend_cache(matched, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_sync(token_ids)

Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
@torch.no_grad()
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    # Walk while tokens can be found
    node, next_token_index, past, base = self.walk_cache(token_ids)

    # The whole prompt is already cached: nothing to evaluate.
    if next_token_index == len(token_ids):
        return node.logprobs

    # The model input is `token_ids[base:]`, so it must be paired with the
    # key/value states found at `base`, i.e. `past` as returned by
    # `walk_cache`. The deepest matched node may hold no states of its own even
    # though an ancestor does; using `node.past_key_values` in that case would
    # evaluate a truncated prompt with no past at all.
    logits = self.model(
        torch.tensor([token_ids[base:]]).to(self.device),
        past_key_values=past,
        use_cache=past is not None,
    ).logits[0]

    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_uncached(token_ids)

Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/hf.py
def next_token_logprobs_uncached(self, token_ids):
    """Compute next-token log probabilities with a plain forward pass.

    No KV or output caching is used, and auto-batching is not supported.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    with torch.no_grad():
        # Single-sequence batch holding the full prompt.
        batch = torch.tensor([token_ids]).to(self.device)
        output = self.model(batch, past_key_values=None, use_cache=False)
        # Only the logits at the last prompt position are normalized and returned.
        final_logits = output.logits[0][-1]
        return torch.log_softmax(final_logits, dim=0)

AsyncLM

Bases: ABC

Abstract base class for asynchronous language models.

This class provides an interface for language models that can generate token probabilities asynchronously. It handles tokenization and vocabulary management.

Parameters:

Name Type Description Default
tokenizer

A Hugging Face tokenizer instance compatible with the language model

required
Source code in genlm/backend/llm/base.py
class AsyncLM(ABC):
    """Interface for language models that serve next-token probabilities asynchronously.

    Concrete subclasses supply `next_token_logprobs` and `next_token_logprobs_sync`;
    batching and sampling helpers are built on top of those here. The class also
    takes care of tokenization and vocabulary management.

    Args:
        tokenizer: A Hugging Face tokenizer instance compatible with the language model
    """

    def __init__(self, tokenizer):
        byte_vocab, str_vocab = decode_vocab(tokenizer)
        self.tokenizer = tokenizer
        self.byte_vocab = byte_vocab
        self.str_vocab = str_vocab

    @abstractmethod
    async def next_token_logprobs(self, token_ids):
        """Asynchronously compute the log probabilities of the next token.

        Args:
            token_ids (list[int]): A list of token IDs representing the prompt.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """

    @abstractmethod
    def next_token_logprobs_sync(self, token_ids):
        """Synchronously compute the log probabilities of the next token.

        Args:
            token_ids (list[int]): A list of token IDs representing the prompt.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """

    async def batch_next_token_logprobs(self, token_ids_list):
        """Asynchronously compute next-token log probabilities for several prompts.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists.

        Returns:
            (torch.Tensor): A tensor of log probability tensors.
        """
        # Issue all requests concurrently, then stack the per-prompt results.
        requests = (self.next_token_logprobs(ids) for ids in token_ids_list)
        per_prompt = await asyncio.gather(*requests)
        return torch.stack(per_prompt)

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """Synchronously compute next-token log probabilities for several prompts.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists.

        Returns:
            (torch.Tensor): A tensor of log probability tensors.
        """
        per_prompt = [self.next_token_logprobs_sync(ids) for ids in token_ids_list]
        return torch.stack(per_prompt)

    def add_new_lora(self, lora_path, lora_name):
        """Load a LoRA adapter into the base model.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Name to assign to the loaded adapter.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        message = "add_new_lora must be implemented by subclasses"
        raise NotImplementedError(message)  # pragma: no cover

    def set_lora(self, lora_path, lora_name):
        """Activate a previously loaded LoRA adapter.

        Args:
            lora_path (str): Accepted by the signature; its use is left to subclasses.
            lora_name (str): Name of the LoRA adapter to activate.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        message = "set_lora must be implemented by subclasses"
        raise NotImplementedError(message)  # pragma: no cover

    def clear_lora(self):
        """Deactivate all LoRA adapters.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        message = "clear_lora must be implemented by subclasses"
        raise NotImplementedError(message)  # pragma: no cover

    def clear_cache(self):
        """Clear any caches used by the language model. No-op in base class."""
        return None  # pragma: no cover

    async def sample(
        self, prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None
    ):
        """Sample a continuation from the language model, one token at a time.

        Generation stops after `max_tokens` tokens or as soon as a sampled token
        is one of `eos_token_ids`; that end-of-sequence token is not returned.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
            temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[int]): The sampled token IDs.
        """
        generator = None
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)

        sampled = []
        for _ in range(max_tokens):
            context = prompt_token_ids + sampled
            logprobs = await self.next_token_logprobs(context)
            probs = torch.softmax(logprobs / temperature, dim=-1)
            if seed is not None:
                # The seeded generator is built without a device argument,
                # so the draw is made from a CPU copy of the probabilities.
                probs = probs.cpu()
            token = torch.multinomial(
                probs, num_samples=1, generator=generator
            ).item()
            if token in eos_token_ids:
                break
            sampled.append(token)

        return sampled

    async def batch_sample(
        self,
        prompt_token_ids_list,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample a continuation for each of several prompts concurrently.

        Every prompt is sampled with the same `max_tokens`, `eos_token_ids`,
        `temperature` and `seed`.

        Args:
            prompt_token_ids_list (list[list[int]]): The token IDs of the prompts.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence token.
            temperature (float): The temperature to use for the logits.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[list[int]]): The sampled token IDs.
        """
        shared = dict(
            max_tokens=max_tokens,
            eos_token_ids=eos_token_ids,
            temperature=temperature,
            seed=seed,
        )
        runs = [
            self.sample(prompt_token_ids=ids, **shared)
            for ids in prompt_token_ids_list
        ]
        return await asyncio.gather(*runs)

next_token_logprobs(token_ids) abstractmethod async

Request log probabilities of next token asynchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs representing the prompt.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
@abstractmethod
async def next_token_logprobs(self, token_ids):
    """Asynchronously compute the log probabilities of the next token.

    Abstract: this declaration has no implementation of its own.

    Args:
        token_ids (list[int]): A list of token IDs representing the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """

next_token_logprobs_sync(token_ids) abstractmethod

Request log probabilities of next token synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs representing the prompt.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
@abstractmethod
def next_token_logprobs_sync(self, token_ids):
    """Synchronously compute the log probabilities of the next token.

    Abstract: this declaration has no implementation of its own.

    Args:
        token_ids (list[int]): A list of token IDs representing the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """

batch_next_token_logprobs(token_ids_list) async

Batch request log probabilities for multiple token sequences asynchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists.

required

Returns:

Type Description
Tensor

A tensor of log probability tensors.

Source code in genlm/backend/llm/base.py
async def batch_next_token_logprobs(self, token_ids_list):
    """Asynchronously compute next-token log probabilities for several prompts.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists.

    Returns:
        (torch.Tensor): A tensor of log probability tensors.
    """
    # Issue all requests concurrently, then stack the per-prompt results.
    requests = (self.next_token_logprobs(ids) for ids in token_ids_list)
    per_prompt = await asyncio.gather(*requests)
    return torch.stack(per_prompt)

batch_next_token_logprobs_sync(token_ids_list)

Batch request log probabilities for multiple token sequences synchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists.

required

Returns:

Type Description
Tensor

A tensor of log probability tensors.

Source code in genlm/backend/llm/base.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """Synchronously compute next-token log probabilities for several prompts.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists.

    Returns:
        (torch.Tensor): A tensor of log probability tensors.
    """
    # Prompts are evaluated one after another, in the order given.
    per_prompt = []
    for token_ids in token_ids_list:
        per_prompt.append(self.next_token_logprobs_sync(token_ids))
    return torch.stack(per_prompt)

add_new_lora(lora_path, lora_name)

Load a LoRA adapter into the base model.

Parameters:

Name Type Description Default
lora_path str

Path to the adapter weights directory or identifier in HuggingFace's model hub.

required
lora_name str

Name to assign to the loaded adapter.

required
Source code in genlm/backend/llm/base.py
def add_new_lora(self, lora_path, lora_name):
    """Load a LoRA adapter into the base model.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Name to assign to the loaded adapter.

    Raises:
        NotImplementedError: Always; subclasses must override this method.
    """
    message = "add_new_lora must be implemented by subclasses"
    raise NotImplementedError(message)  # pragma: no cover

set_lora(lora_path, lora_name)

Activate a previously loaded LoRA adapter.

Parameters:

Name Type Description Default
lora_name str

Name of the LoRA adapter to activate.

required
Source code in genlm/backend/llm/base.py
def set_lora(self, lora_path, lora_name):
    """Activate a previously loaded LoRA adapter.

    Args:
        lora_path (str): Accepted by the signature; its use is left to subclasses.
        lora_name (str): Name of the LoRA adapter to activate.

    Raises:
        NotImplementedError: Always; subclasses must override this method.
    """
    message = "set_lora must be implemented by subclasses"
    raise NotImplementedError(message)  # pragma: no cover

clear_lora()

Deactivate all LoRA adapters.

Source code in genlm/backend/llm/base.py
def clear_lora(self):
    """Deactivate all LoRA adapters.

    Raises:
        NotImplementedError: Always; subclasses must override this method.
    """
    message = "clear_lora must be implemented by subclasses"
    raise NotImplementedError(message)  # pragma: no cover

clear_cache()

Clear any caches used by the language model. No-op in base class.

Source code in genlm/backend/llm/base.py
def clear_cache(self):
    """Clear any caches used by the language model. No-op in base class."""
    return None  # pragma: no cover

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Parameters:

Name Type Description Default
prompt_token_ids list[int]

The token IDs of the prompt.

required
eos_token_ids list[int]

The token IDs of the end-of-sequence tokens.

required
temperature float

The temperature to use to rescale the logits. Defaults to 1.0.

1.0
max_tokens int

The maximum number of tokens to generate.

required
seed int

The seed for the random number generator. Defaults to None.

None

Returns:

Type Description
list[int]

The sampled token IDs.

Source code in genlm/backend/llm/base.py
async def sample(
    self, prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None
):
    """Sample a continuation from the language model, one token at a time.

    Generation stops after `max_tokens` tokens or as soon as a sampled token is
    one of `eos_token_ids`; that end-of-sequence token is not returned.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    sampled = []
    for _ in range(max_tokens):
        context = prompt_token_ids + sampled
        logprobs = await self.next_token_logprobs(context)
        probs = torch.softmax(logprobs / temperature, dim=-1)
        if seed is not None:
            # The seeded generator is built without a device argument,
            # so the draw is made from a CPU copy of the probabilities.
            probs = probs.cpu()
        token = torch.multinomial(probs, num_samples=1, generator=generator).item()
        if token in eos_token_ids:
            break
        sampled.append(token)

    return sampled

batch_sample(prompt_token_ids_list, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Batch sample from the language model.

Parameters:

Name Type Description Default
prompt_token_ids_list list[list[int]]

The token IDs of the prompts.

required
max_tokens int

The maximum number of tokens to generate.

required
eos_token_ids list[int]

The token IDs of the end-of-sequence tokens.

required
temperature float

The temperature to use for the logits.

1.0
seed int

The seed for the random number generator. Defaults to None.

None

Returns:

Type Description
list[list[int]]

The sampled token IDs.

Source code in genlm/backend/llm/base.py
async def batch_sample(
    self,
    prompt_token_ids_list,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Sample a continuation for each of several prompts concurrently.

    Every prompt is sampled with the same `max_tokens`, `eos_token_ids`,
    `temperature` and `seed`.

    Args:
        prompt_token_ids_list (list[list[int]]): The token IDs of the prompts.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence token.
        temperature (float): The temperature to use for the logits.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[list[int]]): The sampled token IDs.
    """
    # Settings shared by every per-prompt sampling run.
    shared = dict(
        max_tokens=max_tokens,
        eos_token_ids=eos_token_ids,
        temperature=temperature,
        seed=seed,
    )
    runs = [
        self.sample(prompt_token_ids=ids, **shared) for ids in prompt_token_ids_list
    ]
    return await asyncio.gather(*runs)

MockAsyncLM

Bases: AsyncLM

Mock implementation of AsyncLM used for testing.

Source code in genlm/backend/llm/base.py
class MockAsyncLM(AsyncLM):
    """Mock implementation of AsyncLM used for testing.

    Instead of running a model it returns pseudo-random log probabilities that
    are a deterministic function of the input token IDs.
    """

    def __init__(self, tokenizer):
        """Initialize a `MockAsyncLM` instance.

        Args:
            tokenizer: Hugging Face tokenizer instance
        """
        super().__init__(tokenizer)
        # Reseeded on every `_get_logprobs` call; 42 is only the initial state.
        self._rng = np.random.RandomState(42)

    @classmethod
    def from_name(cls, model_name, **kwargs):
        """Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

        Args:
            model_name (str): Name of pretrained model to load tokenizer from
            **kwargs: Additional arguments passed to `MockAsyncLM` constructor

        Returns:
            (MockAsyncLM): `MockAsyncLM` instance initialized with tokenizer from `model_name`
        """
        # Imported lazily so transformers is only needed when this constructor is used.
        from transformers import AutoTokenizer

        return cls(AutoTokenizer.from_pretrained(model_name), **kwargs)

    async def next_token_logprobs(self, token_ids):
        """Get next token log probabilities asynchronously.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self._get_logprobs(token_ids)

    def next_token_logprobs_sync(self, token_ids):
        """Get next token log probabilities synchronously.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self._get_logprobs(token_ids)

    def _get_logprobs(self, token_ids):
        """Generate random but deterministic log probabilities for given tokens.

        Uses token_ids to seed the random generator, ensuring same inputs produce same outputs.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        # Position-weighted sum, so reordering the tokens changes the seed.
        # `RandomState.seed` rejects integers outside [0, 2**32 - 1], and the
        # sum exceeds that for long prompts with large token IDs, so it is
        # wrapped into the valid range. Seeds already in range are unchanged.
        seed = sum((i + 1) * t for i, t in enumerate(token_ids)) % (2**32)
        self._rng.seed(seed)
        logits = torch.from_numpy(
            self._rng.rand(len(self.byte_vocab)).astype(np.float32)
        )
        return torch.log_softmax(logits, dim=-1)

__init__(tokenizer)

Initialize a MockAsyncLM instance.

Parameters:

Name Type Description Default
tokenizer

Hugging Face tokenizer instance

required
Source code in genlm/backend/llm/base.py
def __init__(self, tokenizer):
    """Set up a `MockAsyncLM` around the given tokenizer.

    Args:
        tokenizer: Hugging Face tokenizer instance
    """
    super().__init__(tokenizer)
    # Random source for the mock outputs, created with a fixed seed.
    self._rng = np.random.RandomState(seed=42)

from_name(model_name, **kwargs) classmethod

Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

Parameters:

Name Type Description Default
model_name str

Name of pretrained model to load tokenizer from

required
**kwargs

Additional arguments passed to MockAsyncLM constructor

{}

Returns:

Type Description
MockAsyncLM

MockAsyncLM instance initialized with tokenizer from model_name

Source code in genlm/backend/llm/base.py
@classmethod
def from_name(cls, model_name, **kwargs):
    """Build a MockAsyncLM over the vocabulary of a pretrained model's tokenizer.

    Args:
        model_name (str): Name of pretrained model to load tokenizer from
        **kwargs: Additional arguments passed to `MockAsyncLM` constructor

    Returns:
        (MockAsyncLM): `MockAsyncLM` instance initialized with tokenizer from `model_name`
    """
    # Imported lazily so transformers is only needed when this constructor is used.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return cls(tokenizer, **kwargs)

next_token_logprobs(token_ids) async

Get next token log probabilities asynchronously.

Parameters:

Name Type Description Default
token_ids list[int]

Input token IDs.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
async def next_token_logprobs(self, token_ids):
    """Asynchronously return the mock next-token log probabilities.

    Delegates to `_get_logprobs`; nothing is awaited.

    Args:
        token_ids (list[int]): Input token IDs.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    logprobs = self._get_logprobs(token_ids)
    return logprobs

next_token_logprobs_sync(token_ids)

Get next token log probabilities synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

Input token IDs.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/base.py
def next_token_logprobs_sync(self, token_ids):
    """Synchronously return the mock next-token log probabilities.

    Delegates to `_get_logprobs`.

    Args:
        token_ids (list[int]): Input token IDs.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    logprobs = self._get_logprobs(token_ids)
    return logprobs

AsyncMlxLM

Bases: AsyncLM

Asynchronous MLX-based language model wrapper.

This class provides an async interface to MLX language models with automatic batching, caching, and KV cache management. It extends AsyncLM to provide efficient batched inference with prefix caching.

The model automatically batches concurrent requests and uses a trie-based cache to store computed log probabilities and KV states for reuse.

Source code in genlm/backend/llm/mlx.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
class AsyncMlxLM(AsyncLM):
    """Asynchronous MLX-based language model wrapper.

    This class provides an async interface to MLX language models with
    automatic batching, caching, and KV cache management. It extends
    AsyncLM to provide efficient batched inference with prefix caching.

    The model automatically batches concurrent requests and uses a trie-based
    cache to store computed log probabilities and KV states for reuse.
    """

    def __init__(
        self,
        mlx_lm_model,
        tokenizer,
        batch_size=5,
        timeout=0.001,
        prefill_step_size=2048,
        cache_size=400,
    ):
        """Initialize an `AsyncMlxLM` instance.

        Args:
            mlx_lm_model: The MLX language model instance.
            tokenizer: The tokenizer for encoding/decoding text.
            batch_size (int, optional): Maximum number of queries to batch
                together.
            timeout (float, optional): Maximum time in seconds to wait
                before processing a batch, even if batch_size is not met.
            prefill_step_size (int, optional): Number of tokens to process
                per step during prompt prefilling.
            cache_size (int, optional): Maximum number of KV cache entries
                to keep in memory.
        """
        self.mlx_lm_model = mlx_lm_model
        self.tokenizer = tokenizer
        self.cache = DynamicTokenTrie()
        self.generation_stream = mx.new_stream(mx.default_device())
        self.queries = []
        self.timeout = timeout
        self.timer = None
        self.prefill_step_size = prefill_step_size
        self.cache_size = cache_size

        self.batch_size = batch_size
        self.kv_cachable = self._kv_cachable(self.mlx_lm_model)
        if not self.kv_cachable:
            warnings.warn(
                f"Model {type(self.mlx_lm_model).__name__} does not support KV caching; "
                f"prefix caching will be disabled.",
                UserWarning,
                stacklevel=2,
            )
        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, **kwargs):
        """Create an `AsyncMlxLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load. Can be a Hugging Face
                model identifier or local path.
            **kwargs: Additional arguments passed to `AsyncMlxLM` constructor,
                such as `batch_size`, `timeout`, `prefill_step_size`, `cache_size`.

        Returns:
            AsyncMlxLM: An `AsyncMlxLM` instance with the loaded model and tokenizer.
        """

        model, tokenizer = mlx_lm.load(model_name)
        return cls(model, tokenizer, **kwargs)

    @staticmethod
    def _to_torch(logprobs):
        """Convert MLX arrays into PyTorch tensors."""
        # bfloat16 arrays are cast to float16 before being handed to torch.
        if logprobs.dtype == mx.bfloat16:
            logprobs = logprobs.astype(mx.float16)
        return torch.tensor(logprobs)

    @staticmethod
    def _kv_cachable(mlx_lm_model):
        """Check if an MLX model supports KV cache storage.

        A model is KV-cacheable if all its cache layers are KVCache or
        RotatingKVCache with keep=0. Models that do not define `make_cache`
        are treated as KV-cacheable.
        """
        if not hasattr(mlx_lm_model, "make_cache"):
            return True
        cache = mlx_lm_model.make_cache()
        return all(
            isinstance(c, KVCache)
            or (isinstance(c, RotatingKVCache) and c.keep == 0)
            for c in cache
        )

    def clear_cache(self):
        """Clear the output cache and MLX device cache.

        This method resets the internal token trie cache and clears
        any cached arrays on the MLX device to free memory.
        """
        if self.cache is not None:
            self.cache = DynamicTokenTrie()
        mx.clear_cache()

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple: A 5-tuple containing:
                - node: The deepest node in the cache tree that matches
                    the token sequence, regardless of whether its kv is cached or not
                - next_token_index: Number of tokens matched from the start of token_ids
                - past_kvs: Past key/value states concatenated from cached nodes, or None if no cached states were found
                - kv_node: The cache node where KV states start
                - kv_next_token_index: Number of tokens matched from the start of token_ids for the KV states
        """
        # Walk while tokens can be found
        node = self.cache
        kv_next_token_index = 0
        kv_node = node
        collecting = True
        next_token_index = 0
        past_kvs = []

        while next_token_index < len(token_ids):
            if node.past_key_values is not None and collecting:
                past_kvs.append(node.past_key_values)
                kv_node = node
                kv_next_token_index = next_token_index
            elif next_token_index > 0:
                # Past the root, the first node without stored KV states ends
                # KV collection; the walk itself continues.
                collecting = False
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        past_kvs = None if len(past_kvs) == 0 else mx.concatenate(past_kvs, axis=3)

        return node, next_token_index, past_kvs, kv_node, kv_next_token_index

    def cache_kv(self, token_ids):
        """Pre-compute and cache KV states for a given token sequence.

        Args:
            token_ids (list[int]): Token IDs to evaluate; the query is built
                against the cache root with no future and no past KV states.
        """
        query = Query(token_ids, None, None, self.cache, 0)
        self._batch_logits_custom([query])

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
        to completion."""
        self.queries = []

    def add_to_cache(self, queries, prompt_cache=None, logprobs=None):
        """Add computed log probabilities and KV states to the cache tree.

        Args:
            queries (list[Query]): Queries that were just evaluated. A query
                whose `node` or `next_token_index` is None is stored starting
                from the cache root at index 0.
            prompt_cache (list, optional): Per-layer cache objects holding the
                batch's keys/values. When None, or when the model is not
                KV-cacheable, only log probabilities are stored.
            logprobs: Per-query log probabilities, indexed in the same order
                as `queries`. Must be provided despite the None default.
        """
        # Only touch `prompt_cache` when KV states will actually be stored, so
        # that the declared default `prompt_cache=None` does not fail on
        # `prompt_cache[0]`.
        store_kv = prompt_cache is not None and self.kv_cachable
        left_paddings = prompt_cache[0].left_padding.tolist() if store_kv else None
        for i, query in enumerate(queries):
            token_ids, node, next_token_index = (
                query.prompt,
                query.node,
                query.next_token_index,
            )
            if node is None or next_token_index is None:
                node = self.cache
                next_token_index = 0
            if store_kv:
                lp = left_paddings[i]
                # Slice out this query's not-yet-cached positions, skipping
                # its left padding, from every layer's cache.
                keys = [
                    c.keys[i, :, lp + next_token_index : lp + len(token_ids), :]
                    for c in prompt_cache
                ]
                values = [
                    c.values[i, :, lp + next_token_index : lp + len(token_ids), :]
                    for c in prompt_cache
                ]
                keys = mx.stack(keys, axis=0)
                values = mx.stack(values, axis=0)
                keys_values = mx.stack([keys, values], axis=0)
                node.extend_cache(
                    next_token_index, token_ids, logprobs[i], keys_values
                )
            else:
                node.extend_cache(next_token_index, token_ids, logprobs[i])

        self.cache.evict_lru_kv(self.cache_size)

    def _process_kv(self, left_paddings, prompt_cache, pasts=None, step_size=256):
        """Process and integrate past KV cache states into prompt cache.

        This method takes past key-value cache states from the cache tree
        and integrates them into the prompt cache for efficient prefix
        reuse. It handles padding and alignment of cache states across
        different query lengths.

        Args:
            left_paddings (list[int]): Left padding amounts for each query
                in the batch.
            prompt_cache (list): List of cache objects to update with
                past states.
            pasts (list[mx.array], optional): List of past KV cache states,
                one per query.
            step_size (int, optional): Step size for cache size alignment.

        Returns:
            tuple: A 2-tuple containing:
                - list: Updated prompt_cache objects
                - cached_len: Number of tokens that were cached
        """
        if pasts is None or all(past is None for past in pasts):
            return prompt_cache, 0
        max_match_lengths = [0 if past is None else past.shape[3] for past in pasts]
        # Smallest padded position up to which every query in the batch has
        # cached states; only that common prefix can be reused.
        min_pos_cached = min(
            ml + lp for ml, lp in zip(max_match_lengths, left_paddings)
        )
        cache_grabs = [max(min_pos_cached - lp, 0) for lp in left_paddings]
        non_zero_index = next(
            (i for i, grab in enumerate(cache_grabs) if grab), None
        )
        if non_zero_index is None:
            return prompt_cache, 0
        _, num_layers, N, _, D = pasts[non_zero_index].shape
        # Round the cached length up to a multiple of step_size.
        cache_size = (step_size + min_pos_cached - 1) // step_size * step_size
        right_paddings = [
            max(cache_size - lp - max_len, 0)
            for lp, max_len in zip(left_paddings, max_match_lengths)
        ]
        padded_pasts = []
        for past, lp, rp in zip(pasts, left_paddings, right_paddings):
            if past is None:
                padded_pasts.append(mx.zeros((2, num_layers, N, cache_size, D)))
            else:
                padded_pasts.append(
                    mx.pad(
                        past[:, :, :, : cache_size - lp, :],
                        ((0, 0), (0, 0), (0, 0), (lp, rp), (0, 0)),
                    )
                )

        padded_pasts = mx.stack(padded_pasts, axis=2)
        for i, cache in enumerate(prompt_cache):
            cache.keys = padded_pasts[0, i]
            cache.values = padded_pasts[1, i]
            cache.offset += min_pos_cached
            cache._idx += min_pos_cached
        return prompt_cache, min_pos_cached

    def _process_prompts(self, queries):
        """Process a batch of prompts and compute next-token log probabilities.

        Args:
            queries (list[Query]): Queries whose `prompt` and `past` fields
                are read.

        Returns:
            tuple: `(logprobs, prompt_cache)` where `logprobs` holds the
                normalized next-token log probabilities for each query and
                `prompt_cache` is the cache populated while processing.
        """
        inputs = [q.prompt for q in queries]
        pasts = [q.past for q in queries]
        lengths = [len(p) for p in inputs]
        max_length = max(lengths)
        left_padding = [max_length - length for length in lengths]
        prompt_cache = _make_cache(self.mlx_lm_model, left_padding)
        inputs_padded = _left_pad_prompts(inputs, max_length=max_length)

        if self.kv_cachable:
            prompt_cache, cached_len = self._process_kv(
                left_padding, prompt_cache, pasts
            )
        else:
            cached_len = 0
        # Skip the positions already covered by reused KV states.
        inputs_padded = inputs_padded[:, cached_len:]

        # Prefill in chunks, always leaving at least one token for the final
        # forward pass that produces the logits.
        while inputs_padded.shape[1] > 1:
            n_to_process = min(self.prefill_step_size, inputs_padded.shape[1] - 1)
            self.mlx_lm_model(inputs_padded[:, :n_to_process], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            inputs_padded = inputs_padded[:, n_to_process:]

        logits = self.mlx_lm_model(inputs_padded, cache=prompt_cache)
        logits = logits[:, -1, :]
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        mx.async_eval(logprobs)

        return logprobs, prompt_cache

    def _batch_logits_custom(
        self,
        queries,
    ):
        """Compute next-token log probabilities for each query in a batch and add to cache.
        Args:
            queries (list[Query]): List of query objects to process.
        Returns:
            logprobs (list[torch.Tensor]): List of normalized log probability tensors."""
        with wired_limit(self.mlx_lm_model, [self.generation_stream]):
            logprobs, prompt_cache = self._process_prompts(queries)
            logprobs = AsyncMlxLM._to_torch(logprobs)
        mx.clear_cache()
        self.add_to_cache(queries, prompt_cache, logprobs)
        return logprobs

    def batch_evaluate_queries(self):
        """Process a batch of queued language model queries.

        Queries with identical prompts are evaluated once and share the
        result. If evaluation raises, the exception is delivered to every
        pending future so that awaiting callers are woken up instead of
        hanging; it is re-raised only when no future could receive it.
        """

        queries, self.queries = self.queries, []
        if len(queries) == 0:
            return

        query_groups = defaultdict(list)
        for query in queries:
            key = tuple(query.prompt)
            query_groups[key].append(query)

        # Use one representative query from each group
        unique_queries = [group[0] for group in query_groups.values()]

        try:
            results = self._batch_logits_custom(unique_queries)
        except Exception as exc:
            delivered = False
            for query in queries:
                future = query.future
                if future is not None and not future.done():
                    future.set_exception(exc)
                    delivered = True
            if not delivered:
                raise
            return

        assert len(results) == len(unique_queries)

        # Groups and results are in the same order as unique_queries.
        for group, result in zip(query_groups.values(), results):
            for dup_query in group:
                future = dup_query.future
                # A waiter may have been cancelled in the meantime; calling
                # set_result on a finished future raises InvalidStateError
                # and would starve the remaining waiters.
                if future is not None and not future.done():
                    future.set_result(result)

    def add_query(self, query):
        """Add a query to be evaluated in the next batch and reset the timeout.

        Args:
            query (Query): The query to enqueue. The batch is evaluated
                immediately once `batch_size` queries are queued; otherwise a
                timer is scheduled to evaluate it after `timeout` seconds.
        """
        self.queries.append(query)

        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

        Raises:
            ValueError: If `token_ids` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, kv_node, kv_next_token_index = (
            self.walk_cache(token_ids)
        )
        # Full match with stored log probabilities: answer from the cache.
        if next_token_index == len(token_ids) and node.logprobs is not None:
            return node.logprobs

        future = asyncio.get_running_loop().create_future()
        query = Query(token_ids, future, past, kv_node, kv_next_token_index)
        self.add_query(query)
        logprobs = await future

        return logprobs

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.

        Raises:
            ValueError: If `token_ids` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, kv_node, kv_next_token_index = (
            self.walk_cache(token_ids)
        )
        # Full match with stored log probabilities: answer from the cache.
        if next_token_index == len(token_ids) and node.logprobs is not None:
            return node.logprobs

        query = Query(token_ids, None, past, kv_node, kv_next_token_index)
        logprobs = self._batch_logits_custom([query])[0]

        return logprobs

    async def sample(
        self,
        prompt_token_ids,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample from the language model.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt to
                start generation from.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs that signal
                end-of-sequence. Generation stops when one of these is
                sampled.
            temperature (float, optional): The temperature to use for
                sampling. Higher values make the distribution more uniform,
                lower values make it more peaked. Defaults to 1.0.
            seed (int, optional): The seed for the random number generator.
                If provided, sets the random seed before sampling.
                Defaults to None.

        Returns:
            (list[int]): The sampled token IDs.
        """

        if seed is not None:
            mx.random.seed(seed)

        sampler = make_sampler(temp=temperature)
        prompt_token_ids_array = mx.array(prompt_token_ids)
        token_generator = generate_step(
            prompt_token_ids_array,
            self.mlx_lm_model,
            max_tokens=max_tokens,
            sampler=sampler,
        )
        generated_token_ids = []
        for sampled, _ in token_generator:
            # The end-of-sequence token itself is not included in the output.
            if sampled in eos_token_ids:
                break
            generated_token_ids.append(sampled)
        return generated_token_ids

__init__(mlx_lm_model, tokenizer, batch_size=5, timeout=0.001, prefill_step_size=2048, cache_size=400)

Initialize an AsyncMlxLM instance.

Parameters:

Name Type Description Default
mlx_lm_model

The MLX language model instance.

required
tokenizer

The tokenizer for encoding/decoding text.

required
batch_size int

Maximum number of queries to batch together.

5
timeout float

Maximum time in seconds to wait before processing a batch, even if batch_size is not met.

0.001
prefill_step_size int

Number of tokens to process per step during prompt prefilling.

2048
cache_size int

Maximum number of KV cache entries to keep in memory.

400
Source code in genlm/backend/llm/mlx.py
def __init__(
    self,
    mlx_lm_model,
    tokenizer,
    batch_size=5,
    timeout=0.001,
    prefill_step_size=2048,
    cache_size=400,
):
    """Set up the model handles, batching state, and prefix cache.

    Args:
        mlx_lm_model: The MLX language model instance.
        tokenizer: The tokenizer for encoding/decoding text.
        batch_size (int, optional): Upper bound on the number of queries
            evaluated together.
        timeout (float, optional): Seconds to wait before a partially
            filled batch is evaluated anyway.
        prefill_step_size (int, optional): Number of prompt tokens
            processed per prefill step.
        cache_size (int, optional): Maximum number of KV cache entries
            kept in memory.
    """
    # Model handles.
    self.mlx_lm_model = mlx_lm_model
    self.tokenizer = tokenizer

    # Prefix cache and the MLX stream used for generation.
    self.cache = DynamicTokenTrie()
    self.generation_stream = mx.new_stream(mx.default_device())

    # Batching state: pending queries plus the flush timer.
    self.queries = []
    self.timer = None
    self.timeout = timeout
    self.batch_size = batch_size

    # Prefill and cache sizing.
    self.prefill_step_size = prefill_step_size
    self.cache_size = cache_size

    supports_kv = self._kv_cachable(self.mlx_lm_model)
    self.kv_cachable = supports_kv
    if not supports_kv:
        warnings.warn(
            f"Model {type(self.mlx_lm_model).__name__} does not support KV caching; "
            f"prefix caching will be disabled.",
            UserWarning,
            stacklevel=2,
        )

    super().__init__(tokenizer=self.tokenizer)

from_name(model_name, **kwargs) classmethod

Create an AsyncMlxLM instance from a model name.

Parameters:

Name Type Description Default
model_name str

Name of the model to load. Can be a Hugging Face model identifier or local path.

required
**kwargs

Additional arguments passed to AsyncMlxLM constructor, such as batch_size, timeout, prefill_step_size, cache_size.

{}

Returns:

Name Type Description
AsyncMlxLM

An AsyncMlxLM instance with the loaded model and tokenizer.

Source code in genlm/backend/llm/mlx.py
@classmethod
def from_name(cls, model_name, **kwargs):
    """Build an `AsyncMlxLM` by loading a model and tokenizer by name.

    Args:
        model_name (str): Name of the model to load. Can be a Hugging Face
            model identifier or local path.
        **kwargs: Forwarded unchanged to the `AsyncMlxLM` constructor
            (for example `batch_size`, `timeout`, `prefill_step_size`,
            `cache_size`).

    Returns:
        AsyncMlxLM: An instance wrapping the loaded model and tokenizer.
    """
    loaded_model, loaded_tokenizer = mlx_lm.load(model_name)
    instance = cls(loaded_model, loaded_tokenizer, **kwargs)
    return instance

clear_cache()

Clear the output cache and MLX device cache.

This method resets the internal token trie cache and clears any cached arrays on the MLX device to free memory.

Source code in genlm/backend/llm/mlx.py
def clear_cache(self):
    """Reset the token trie cache and release MLX device-side cached arrays.

    A cache that is currently None is left as None; the MLX device cache
    is cleared in either case.
    """
    has_trie = self.cache is not None
    if has_trie:
        # Replace rather than mutate: a fresh, empty trie.
        self.cache = DynamicTokenTrie()
    mx.clear_cache()

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

Name Type Description Default
token_ids list[int]

Sequence of token IDs to follow in the cache tree

required

Returns:

Name Type Description
tuple

A 5-tuple containing: - node: The deepest node in the cache tree that matches the token sequence, regardless of whether its kv is cached or not - next_token_index: Number of tokens matched from the start of token_ids - past_kvs: Past key/value states concatenated from cached nodes, or None if no cached states were found - kv_node: The cache node where KV states start - kv_next_token_index: Number of tokens matched from the start of token_ids for the KV states

Source code in genlm/backend/llm/mlx.py
def walk_cache(self, token_ids):
    """Follow `token_ids` through the cache tree as far as it matches.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple: A 5-tuple containing:
            - node: The deepest node reached along the token sequence,
                regardless of whether its kv is cached or not
            - next_token_index: Number of leading tokens of token_ids that matched
            - past_kvs: Cached key/value states concatenated along axis 3,
                or None if no cached states were collected
            - kv_node: The last node whose KV states were collected
                (the root if none were)
            - kv_next_token_index: Number of tokens matched when kv_node was visited
    """
    root = self.cache
    current = root
    matched = 0

    # Bookkeeping for the contiguous run of nodes that carry KV states.
    kv_anchor = root
    kv_matched = 0
    kv_chunks = []
    still_collecting = True

    for token in token_ids:
        if current.past_key_values is not None and still_collecting:
            kv_chunks.append(current.past_key_values)
            kv_anchor, kv_matched = current, matched
        elif matched > 0:
            # Past the root, a node without collectable KV states ends collection.
            still_collecting = False

        if not current.has_token(token):
            break
        current = current.get_token(token)
        matched += 1

    if kv_chunks:
        joined_kvs = mx.concatenate(kv_chunks, axis=3)
    else:
        joined_kvs = None

    return current, matched, joined_kvs, kv_anchor, kv_matched

cache_kv(token_ids)

Pre-compute and cache KV states for a given token sequence.

Source code in genlm/backend/llm/mlx.py
def cache_kv(self, token_ids):
    """Evaluate `token_ids` once so that its KV states land in the cache.

    The query is anchored at the cache root (index 0) and carries neither
    a future nor past KV states; the computed result is discarded.
    """
    warmup_query = Query(token_ids, None, None, self.cache, 0)
    batch = [warmup_query]
    self._batch_logits_custom(batch)

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/mlx.py
def reset_async_queries(self):
    """Drop every pending language model query from the queue.

    Intended for use after an exception prevented an inference algorithm
    from running to completion. The queue is replaced by a new empty list;
    the previous list object is not modified.
    """
    self.queries = list()

add_to_cache(queries, prompt_cache=None, logprobs=None)

Add computed log probabilities and KV states to the cache tree.

Source code in genlm/backend/llm/mlx.py
def add_to_cache(self, queries, prompt_cache=None, logprobs=None):
    """Add computed log probabilities and KV states to the cache tree.

    Args:
        queries (list[Query]): Queries that were just evaluated. A query
            whose `node` or `next_token_index` is None is stored starting
            from the cache root at index 0.
        prompt_cache (list, optional): Per-layer cache objects holding the
            batch's keys/values. When None, or when the model is not
            KV-cacheable, only log probabilities are stored.
        logprobs: Per-query log probabilities, indexed in the same order
            as `queries`. Must be provided despite the None default.
    """
    # Only touch `prompt_cache` when KV states will actually be stored, so
    # that the declared default `prompt_cache=None` does not fail on
    # `prompt_cache[0]`.
    store_kv = prompt_cache is not None and self.kv_cachable
    left_paddings = prompt_cache[0].left_padding.tolist() if store_kv else None

    for i, query in enumerate(queries):
        token_ids, node, next_token_index = (
            query.prompt,
            query.node,
            query.next_token_index,
        )
        if node is None or next_token_index is None:
            node = self.cache
            next_token_index = 0

        if not store_kv:
            node.extend_cache(next_token_index, token_ids, logprobs[i])
            continue

        lp = left_paddings[i]
        # Slice out this query's not-yet-cached positions, skipping its
        # left padding, from every layer's cache.
        keys = [
            c.keys[i, :, lp + next_token_index : lp + len(token_ids), :]
            for c in prompt_cache
        ]
        values = [
            c.values[i, :, lp + next_token_index : lp + len(token_ids), :]
            for c in prompt_cache
        ]
        keys = mx.stack(keys, axis=0)
        values = mx.stack(values, axis=0)
        keys_values = mx.stack([keys, values], axis=0)
        node.extend_cache(next_token_index, token_ids, logprobs[i], keys_values)

    self.cache.evict_lru_kv(self.cache_size)

batch_evaluate_queries()

Process a batch of queued language model queries.

Source code in genlm/backend/llm/mlx.py
def batch_evaluate_queries(self):
    """Process a batch of queued language model queries.

    The pending queue is swapped for an empty list, queries with identical
    prompts are grouped, and one representative per group is evaluated via
    `_batch_logits_custom`. Each group's result is then set on the futures
    of all its queries.

    If evaluation raises, the exception is delivered to every pending
    future so that awaiting callers are woken up instead of hanging; it is
    re-raised only when no future could receive it.
    """
    queries, self.queries = self.queries, []
    if len(queries) == 0:
        return

    query_groups = defaultdict(list)
    for query in queries:
        query_groups[tuple(query.prompt)].append(query)

    # Use one representative query from each group
    unique_queries = [group[0] for group in query_groups.values()]

    try:
        results = self._batch_logits_custom(unique_queries)
    except Exception as exc:
        delivered = False
        for query in queries:
            future = query.future
            if future is not None and not future.done():
                future.set_exception(exc)
                delivered = True
        if not delivered:
            raise
        return

    assert len(results) == len(unique_queries)

    # Groups and results are in the same order as unique_queries.
    for group, result in zip(query_groups.values(), results):
        for dup_query in group:
            future = dup_query.future
            # A waiter may have been cancelled in the meantime; calling
            # set_result on a finished future raises InvalidStateError and
            # would starve the remaining waiters.
            if future is not None and not future.done():
                future.set_result(result)

add_query(query)

Add a query to be evaluated in the next batch and reset the timeout.

Source code in genlm/backend/llm/mlx.py
def add_query(self, query):
    """Queue `query` for the next batch and restart the flush timer.

    Any previously scheduled timer is cancelled. Once the queue holds at
    least `batch_size` queries the batch is evaluated immediately;
    otherwise a new timer is scheduled on the running event loop to
    evaluate the batch after `timeout` seconds.
    """
    self.queries.append(query)

    pending_timer = self.timer
    if pending_timer:
        pending_timer.cancel()
        self.timer = None

    if len(self.queries) < self.batch_size:
        loop = asyncio.get_running_loop()
        self.timer = loop.call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )
        return

    self.batch_evaluate_queries()

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm/backend/llm/mlx.py
async def next_token_logprobs(self, token_ids):
    """Asynchronously fetch next-token log probabilities for a prompt.

    Concurrent requests are batched automatically, which is why this must
    be awaited. A prompt that fully matches a cached node with stored log
    probabilities is answered from the cache without queuing a query.

    Args:
        token_ids (list[int]): Token ids representing the prompt.

    Returns:
        logprobs (torch.Tensor): The language model's normalized log
            probabilities for the token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    lookup = self.walk_cache(token_ids)
    node, matched, past, kv_node, kv_matched = lookup

    fully_matched = matched == len(token_ids)
    if fully_matched and node.logprobs is not None:
        return node.logprobs

    # Cache miss: enqueue a query and wait for the batch to resolve it.
    pending = asyncio.get_running_loop().create_future()
    self.add_query(Query(token_ids, pending, past, kv_node, kv_matched))
    return await pending

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/mlx.py
def next_token_logprobs_sync(self, token_ids):
    """Synchronously fetch next-token log probabilities for a prompt.

    A prompt that fully matches a cached node with stored log
    probabilities is answered from the cache; otherwise a single-query
    batch is evaluated right away.

    Args:
        token_ids (list[int]): Token IDs representing the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    lookup = self.walk_cache(token_ids)
    node, matched, past, kv_node, kv_matched = lookup

    fully_matched = matched == len(token_ids)
    if fully_matched and node.logprobs is not None:
        return node.logprobs

    # Cache miss: evaluate this prompt alone and take its single result.
    single = Query(token_ids, None, past, kv_node, kv_matched)
    batch_result = self._batch_logits_custom([single])
    return batch_result[0]

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Parameters:

Name Type Description Default
prompt_token_ids list[int]

The token IDs of the prompt to start generation from.

required
max_tokens int

The maximum number of tokens to generate.

required
eos_token_ids list[int]

The token IDs that signal end-of-sequence. Generation stops when one of these is sampled.

required
temperature float

The temperature to use for sampling. Higher values make the distribution more uniform, lower values make it more peaked. Defaults to 1.0.

1.0
seed int

The seed for the random number generator. If provided, sets the random seed before sampling. Defaults to None.

None

Returns:

Type Description
list[int]

The sampled token IDs.

Source code in genlm/backend/llm/mlx.py
async def sample(
    self,
    prompt_token_ids,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Generate token IDs from the language model by sampling.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt to
            start generation from.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): Token IDs that signal end-of-sequence.
            Generation stops when one of them is sampled; that token is
            not included in the result.
        temperature (float, optional): Sampling temperature passed to the
            sampler. Defaults to 1.0.
        seed (int, optional): If provided, seeds the MLX random number
            generator before sampling. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs.
    """
    if seed is not None:
        mx.random.seed(seed)

    token_sampler = make_sampler(temp=temperature)
    steps = generate_step(
        mx.array(prompt_token_ids),
        self.mlx_lm_model,
        max_tokens=max_tokens,
        sampler=token_sampler,
    )

    output_ids = []
    for token, _ in steps:
        reached_eos = token in eos_token_ids
        if reached_eos:
            break
        output_ids.append(token)
    return output_ids

AsyncSGLTransformer

Bases: AsyncLM

Asynchronous wrapper around a SGLang inference engine.

This class provides an asynchronous interface to SGLang inference engine with automatic batching and caching. It extends AsyncLM to provide efficient batched inference.

The model automatically batches concurrent requests and uses a cache to store computed log probabilities for reuse.

Source code in genlm/backend/llm/sgl.py
class AsyncSGLTransformer(AsyncLM):
    """Asynchronous wrapper around a SGLang inference engine.

    This class provides an asynchronous interface to the SGLang inference engine
    with automatic batching and caching. It extends AsyncLM to provide efficient
    batched inference.

    Concurrent requests are pushed onto an `asyncio.Queue` and evaluated
    together by a background task; requests for the same token sequence share
    a single engine request. When `cache_size > 0`, computed log probabilities
    are stored in an `OutputCache` for reuse.
    """

    def __init__(self, sgl_model, cache_size=0, cache_opts=None):
        """Initialize an `AsyncSGLTransformer` instance.

        Args:
            sgl_model: The SGLang inference engine instance. It must expose a
                `tokenizer` attribute.
            cache_size (int, optional): Maximum number of log probabilities to keep in memory.
                No cache is created unless this is greater than 0. Defaults to 0.
            cache_opts (dict, optional): Additional configuration options for the cache.
                Defaults to None (no extra options).
        """
        self.model = sgl_model
        self.tokenizer = sgl_model.tokenizer

        cache_opts = {} if cache_opts is None else cache_opts
        # `self.cache` stays None when caching is disabled; callers check for None.
        self.cache = (
            OutputCache(maxsize=cache_size, **cache_opts)
            if cache_size > 0
            else None
        )

        # Queue and background task are created lazily by `_start`.
        self._queue: Optional[asyncio.Queue] = None
        self._task: Optional[asyncio.Task] = None

        # token-id tuple -> futures waiting on that sequence's result.
        self._pending: Dict[Tuple[int, ...], List[asyncio.Future]] = {}
        # token-id tuple -> the engine request already issued for it.
        self._inflight: Dict[Tuple[int, ...], Request] = {}

        # Request id -> token-id tuple, used to map finished requests back to keys.
        self._rid_to_token_ids: Dict[str, Tuple[int, ...]] = {}

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_id, engine_opts=None, gpu_id=0, **kwargs):
        """Create an `AsyncSGLTransformer` instance from a model name.

        Args:
            model_id (str): The name of the model to load.
            engine_opts (dict, optional): Additional configuration options for the SGLang inference engine.
                Entries override the defaults set in this method.
            gpu_id (int, optional): The GPU ID to use for the inference engine. Defaults to 0.
            **kwargs: Additional arguments passed to the `AsyncSGLTransformer` constructor.

        Returns:
            (AsyncSGLTransformer): An initialized `AsyncSGLTransformer` instance.
        """
        _engine_opts = {
            "sampling_backend": "pytorch",
            "skip_tokenizer_init": False,
            "model_path": model_id,
            "grammar_backend": "none",
            "allow_auto_truncate": False,
            "disable_overlap_schedule": False,
            "mem_fraction_static": 0.9,  # default value is 0.9
        }
        # Caller-supplied options take precedence over the defaults above.
        if engine_opts:
            _engine_opts.update(engine_opts)
        server_args = ServerArgs(**_engine_opts)
        port_args = PortArgs.init_new(server_args)
        # NOTE(review): the four trailing zeros are positional Scheduler
        # arguments whose meaning is not visible here — confirm against the
        # SGLang Scheduler signature.
        mod = Scheduler(server_args, port_args, gpu_id, 0, 0, 0, 0)
        mod.result_queue = deque()
        return cls(mod, **kwargs)

    def clear_cache(self):
        """Clear the logprobs output cache.

        Does nothing when `self.cache` is falsy (for example, when caching is
        disabled and the cache is None).
        """
        if self.cache:
            self.cache.clear()

    def clear_kv_cache(self):
        """Clear the SGLang cache.

        Returns:
            Whatever the engine's `flush_cache()` returns.
        """
        return self.model.flush_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue.

        Use this method when an exception prevented an inference algorithm from
        executing to completion. It cancels every registered and queued future,
        clears the request bookkeeping, cancels the background task if it is
        still running, and drops the queue so the next request starts fresh.
        """

        for waiters in self._pending.values():
            for fut in waiters:
                fut.cancel()

        self._pending.clear()
        self._inflight.clear()
        self._rid_to_token_ids.clear()

        # Cancel futures that were queued but never registered.
        if self._queue:
            while True:
                try:
                    _, fut = self._queue.get_nowait()
                    fut.cancel()
                except asyncio.QueueEmpty:
                    break

        if self._task and not self._task.done():
            self._task.cancel()
        self._task = None
        self._queue = None

    def _start(self):
        """Start the background loop if it is not already running.

        A new queue is created together with the new task.
        """
        if not self._task or self._task.done():
            self._queue = asyncio.Queue()
            self._task = asyncio.create_task(self._background_loop())

    def _queue_request(self, token_ids: Tuple[int, ...]):
        """Queue a request to the SGLang inference engine.

        Starts the background loop first if it is not running.

        Args:
            token_ids (tuple[int]): The token IDs of the request.

        Returns:
            (asyncio.Future): A future that will be set with the result of the request.
        """
        if not self._task or self._task.done():
            self._start()
        fut = asyncio.get_running_loop().create_future()
        self._queue.put_nowait((token_ids, fut))
        return fut

    async def next_token_logprobs(self, token_ids: List[int]):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's normalized log probabilities for the next token following the prompt.

        Raises:
            ValueError: If `token_ids` is empty.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        out = await self._queue_request(key)

        if self.cache is not None:
            self.cache[key] = out

        return out

    def next_token_logprobs_sync(self, token_ids: List[int]):
        """Request log probabilities of next token synchronously.

        Delegates to `batch_next_token_logprobs_sync` with a single prompt.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self.batch_next_token_logprobs_sync([token_ids])[0]

    def batch_next_token_logprobs_sync(self, token_ids_list: List[List[int]]):
        """Request log probabilities of next tokens in a batch synchronously.

        Cached prompts are served from the cache and duplicate prompts are
        evaluated only once.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt.

        Returns:
            (torch.Tensor): The per-prompt normalized log probability tensors,
                stacked in the order of `token_ids_list`.

        Raises:
            ValueError: If any prompt in `token_ids_list` is empty.
        """
        results = {}
        to_compute = []

        for token_ids in token_ids_list:
            if not token_ids:
                raise ValueError("Token ids must not be empty")
            key = tuple(token_ids)
            if self.cache is not None and key in self.cache:
                results[key] = self.cache[key]
            elif key not in results:
                to_compute.append(key)
                # Placeholder so a repeated prompt is not scheduled twice.
                results[key] = None

        if to_compute:
            requests = []
            for key in to_compute:
                req = _make_request(key)
                self._rid_to_token_ids[req.rid] = key
                requests.append(req)

            for key, logprobs in self._batch_evaluate(requests):
                results[key] = logprobs
                if self.cache is not None:
                    self.cache[key] = logprobs

        return torch.stack([results[tuple(t)] for t in token_ids_list])

    def _register(self, token_ids: Tuple[int, ...], future: asyncio.Future):
        """Register a request with the SGLang inference engine.

        Args:
            token_ids (Tuple[int]): The token IDs of the request.
            future (asyncio.Future): A future that will be set with the result of the request.

        Returns:
            (Request | None): The Request object that was registered, or None if the
                request future was cancelled or a request for the same token IDs is
                already in flight (the future is then attached to that request).
        """
        if future.cancelled():
            return None

        key = tuple(token_ids)

        self._pending.setdefault(key, []).append(future)

        # Reuse the engine request already issued for this sequence.
        if key in self._inflight:
            return None

        req = _make_request(token_ids)
        self._rid_to_token_ids[req.rid] = key
        self._inflight[key] = req
        return req

    async def _drain_queue(self) -> List[Request]:
        """Wait for at least one item, then drain all available items from the queue.

        Returns:
            (List[Request]): The newly created requests; items whose future was
                cancelled or whose token IDs are already in flight add nothing.
        """
        assert self._queue is not None

        requests = []

        # Wait for at least one item
        token_ids, future = await self._queue.get()
        req = self._register(token_ids, future)
        if req is not None:
            requests.append(req)

        # Take everything else that is already queued, without waiting.
        while True:
            try:
                token_ids, future = self._queue.get_nowait()
                req = self._register(token_ids, future)
                if req is not None:
                    requests.append(req)
            except asyncio.QueueEmpty:
                break

        return requests

    def _batch_evaluate(self, requests: List[Request]):
        """Evaluate a batch of requests and return the token IDs and log probabilities.

        This is a generator: it submits `requests` to the engine, runs batches
        until the engine has none left, and yields a `(token_ids, logprobs)`
        pair for each finished request whose id is known to this instance.
        Yields nothing when `requests` is empty.
        """
        if not requests:
            return  # pragma: no cover

        self.model.process_input_requests(requests)

        while batch := self.model.get_next_batch_to_run():
            with torch.inference_mode():
                batch_result = self.model.run_batch(batch)
                self.model.process_batch_result(batch, batch_result)
                logprobs = torch.log_softmax(
                    batch_result.logits_output.next_token_logits, dim=-1
                ).to("cpu")

                # assumes row i of the logits corresponds to batch.reqs[i] — TODO confirm
                for i, req in enumerate(batch.reqs):
                    if req.finished():
                        token_ids = self._rid_to_token_ids.pop(req.rid, None)
                        if token_ids is None:
                            continue  # pragma: no cover
                        yield token_ids, logprobs[i]

    async def _background_loop(self):
        """Background task that processes queued requests from the queue.

        Repeatedly drains the queue, evaluates the resulting requests, and sets
        the result on every future waiting for each finished token sequence.
        On an unexpected exception, the exception is set on all pending futures,
        the bookkeeping is cleared, and the exception is re-raised.
        """
        assert self._queue is not None
        try:
            while True:
                requests = await self._drain_queue()
                for token_ids, logprobs in self._batch_evaluate(requests):
                    waiters = self._pending.pop(token_ids, [])
                    self._inflight.pop(token_ids, None)
                    for f in waiters:
                        f.set_result(logprobs)

        except asyncio.CancelledError:
            # Cancellation is propagated unchanged; pending futures are not touched here.
            raise
        except Exception as e:
            for waiters in self._pending.values():
                for f in waiters:
                    f.set_exception(e)
            self._pending.clear()
            self._inflight.clear()
            self._rid_to_token_ids.clear()
            raise

    def _cleanup_engine(self):
        """Clean up the SGLang inference engine and distributed environment.

        Best-effort: does nothing when the instance has no `model`, and any
        exception raised during teardown is deliberately ignored.
        """
        if getattr(self, "model", None) is None:
            return  # pragma: no cover
        try:
            self.reset_async_queries()
            destroy_model_parallel()
            destroy_distributed_environment()
        except Exception:  # pragma: no cover
            pass  # pragma: no cover

    def __del__(self):  # pragma: no cover
        """Clean up the SGLang inference engine when the instance is deleted."""
        self._cleanup_engine()

__init__(sgl_model, cache_size=0, cache_opts=None)

Initialize an AsyncSGLTransformer instance.

Parameters:

Name Type Description Default
sgl_model

The SGLang inference engine instance.

required
cache_size int

Maximum number of log probabilities to keep in memory.

0
cache_opts dict

Additional configuration options for the cache.

None
Source code in genlm/backend/llm/sgl.py
def __init__(self, sgl_model, cache_size=0, cache_opts=None):
    """Set up an `AsyncSGLTransformer` around an existing engine.

    Args:
        sgl_model: The SGLang inference engine instance; its `tokenizer`
            attribute is reused as this model's tokenizer.
        cache_size (int, optional): Maximum number of log probabilities to keep in memory.
            A cache is only created when this is greater than 0.
        cache_opts (dict, optional): Additional configuration options for the cache.
    """
    self.model = sgl_model
    self.tokenizer = sgl_model.tokenizer

    if cache_opts is None:
        cache_opts = {}
    # Caching is opt-in: without a positive size, `self.cache` is None.
    if cache_size > 0:
        self.cache = OutputCache(maxsize=cache_size, **cache_opts)
    else:
        self.cache = None

    # Created lazily when the first asynchronous request arrives.
    self._queue: Optional[asyncio.Queue] = None
    self._task: Optional[asyncio.Task] = None

    # Futures waiting per token sequence, and the request issued for it.
    self._pending: Dict[Tuple[int, ...], List[asyncio.Future]] = {}
    self._inflight: Dict[Tuple[int, ...], Request] = {}

    # Maps engine request ids back to their token sequences.
    self._rid_to_token_ids: Dict[str, Tuple[int, ...]] = {}

    super().__init__(tokenizer=self.tokenizer)

from_name(model_id, engine_opts=None, gpu_id=0, **kwargs) classmethod

Create an AsyncSGLTransformer instance from a model name.

Parameters:

Name Type Description Default
model_id str

The name of the model to load.

required
engine_opts dict

Additional configuration options for the SGLang inference engine.

None
gpu_id int

The GPU ID to use for the inference engine.

0
**kwargs

Additional arguments passed to the AsyncSGLTransformer constructor.

{}

Returns:

Type Description
AsyncSGLTransformer

An initialized AsyncSGLTransformer instance.

Source code in genlm/backend/llm/sgl.py
@classmethod
def from_name(cls, model_id, engine_opts=None, gpu_id=0, **kwargs):
    """Build an `AsyncSGLTransformer` for the named model.

    Args:
        model_id (str): The name of the model to load.
        engine_opts (dict, optional): Additional configuration options for the SGLang inference engine.
            Entries here override the defaults chosen by this method.
        gpu_id (int, optional): The GPU ID to use for the inference engine.
        **kwargs: Additional arguments passed to the `AsyncSGLTransformer` constructor.

    Returns:
        (AsyncSGLTransformer): An initialized `AsyncSGLTransformer` instance.
    """
    engine_config = dict(
        sampling_backend="pytorch",
        skip_tokenizer_init=False,
        model_path=model_id,
        grammar_backend="none",
        allow_auto_truncate=False,
        disable_overlap_schedule=False,
        mem_fraction_static=0.9,  # default value is 0.9
    )
    # User-provided options win over the defaults above.
    if engine_opts:
        engine_config.update(engine_opts)

    server_args = ServerArgs(**engine_config)
    port_args = PortArgs.init_new(server_args)
    scheduler = Scheduler(server_args, port_args, gpu_id, 0, 0, 0, 0)
    scheduler.result_queue = deque()
    return cls(scheduler, **kwargs)

clear_cache()

Clear the logprobs output cache.

Source code in genlm/backend/llm/sgl.py
def clear_cache(self):
    """Empty the cache of computed log probabilities, if there is one."""
    cache = self.cache
    if not cache:
        # Nothing to clear (e.g. caching is disabled and the cache is None).
        return
    cache.clear()

clear_kv_cache()

Clear the SGLang cache.

Source code in genlm/backend/llm/sgl.py
def clear_kv_cache(self):
    """Flush the SGLang engine's cache and return the engine's response."""
    engine = self.model
    return engine.flush_cache()

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/sgl.py
def reset_async_queries(self):
    """Discard every outstanding language model query.

    Call this after an exception stopped an inference algorithm part-way
    through. All registered and queued futures are cancelled, the request
    bookkeeping is emptied, and the background task and queue are dropped
    so that the next request starts from a clean state.
    """
    # Futures already registered with the background loop.
    for key in self._pending:
        for waiter in self._pending[key]:
            waiter.cancel()

    for table in (self._pending, self._inflight, self._rid_to_token_ids):
        table.clear()

    # Futures still sitting in the queue, not yet registered.
    queue = self._queue
    if queue:
        try:
            while True:
                _, queued = queue.get_nowait()
                queued.cancel()
        except asyncio.QueueEmpty:
            pass

    worker = self._task
    if worker and not worker.done():
        worker.cancel()
    self._task = self._queue = None

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's normalized log probabilities for the next token following the prompt.

Source code in genlm/backend/llm/sgl.py
async def next_token_logprobs(self, token_ids: List[int]):
    """Return next-token log probabilities for a prompt; use with `await`.

    The call is asynchronous because concurrent requests are batched
    automatically. Results are served from the cache when available and
    stored in it after being computed.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's normalized log probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    key = tuple(token_ids)

    cache = self.cache
    if cache is not None and key in cache:
        return cache[key]

    logprobs = await self._queue_request(key)

    # Re-read the attribute: the await above may have taken a while.
    if self.cache is not None:
        self.cache[key] = logprobs
    return logprobs

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm/backend/llm/sgl.py
def next_token_logprobs_sync(self, token_ids: List[int]):
    """Synchronously compute next-token log probabilities for one prompt.

    Runs the prompt through `batch_next_token_logprobs_sync` as a batch of
    one and returns its only entry.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    single_prompt_batch = [token_ids]
    batch_logprobs = self.batch_next_token_logprobs_sync(single_prompt_batch)
    return batch_logprobs[0]

batch_next_token_logprobs_sync(token_ids_list)

Request log probabilities of next tokens in a batch synchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists, each representing a prompt.

required

Returns:

Type Description
Tensor

A tensor of normalized log probability tensors.

Source code in genlm/backend/llm/sgl.py
def batch_next_token_logprobs_sync(self, token_ids_list: List[List[int]]):
    """Synchronously compute next-token log probabilities for several prompts.

    Prompts found in the cache are served from it, and a prompt that occurs
    more than once is evaluated only once.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt.

    Returns:
        (torch.Tensor): The per-prompt normalized log probability tensors,
            stacked in the order of `token_ids_list`.

    Raises:
        ValueError: If any prompt is empty.
    """
    cache = self.cache
    resolved = {}
    missing = []

    for prompt in token_ids_list:
        if not prompt:
            raise ValueError("Token ids must not be empty")
        key = tuple(prompt)
        if cache is not None and key in cache:
            resolved[key] = cache[key]
            continue
        if key in resolved:
            # Already served from cache earlier or already scheduled.
            continue
        missing.append(key)
        resolved[key] = None  # placeholder until the engine produces a result

    if missing:
        requests = []
        for key in missing:
            request = _make_request(key)
            self._rid_to_token_ids[request.rid] = key
            requests.append(request)

        for key, logprobs in self._batch_evaluate(requests):
            resolved[key] = logprobs
            if self.cache is not None:
                self.cache[key] = logprobs

    rows = [resolved[tuple(prompt)] for prompt in token_ids_list]
    return torch.stack(rows)

__del__()

Clean up the SGLang inference engine when the instance is deleted.

Source code in genlm/backend/llm/sgl.py
def __del__(self):  # pragma: no cover
    """Release the SGLang engine's resources when this object is deleted.

    All of the work is delegated to `_cleanup_engine`.
    """
    self._cleanup_engine()

load_model_by_name(name, backend=None, llm_opts=None)

Load a language model by name.

Parameters:

Name Type Description Default
name str

Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct")

required
backend str

Backend to use for inference. Can be "vllm", "hf", "mlx", "sgl", or "mock". If None, defaults to "vllm" if CUDA is available, otherwise "hf".

None
llm_opts dict

Additional options to pass to the backend constructor. See AsyncVirtualLM and AsyncTransformer documentation for details.

None

Returns:

Type Description
AsyncLM

An asynchronous language model.

Raises:

Type Description
ValueError

If an invalid backend is specified.

Source code in genlm/backend/llm/__init__.py
def load_model_by_name(name, backend=None, llm_opts=None):
    """Construct an asynchronous language model for the given model name.

    Args:
        name (str): Hugging Face model name (e.g. "gpt2", "meta-llama/Llama-3.2-1B-Instruct")
        backend (str, optional): Backend to use for inference. Can be "vllm", "hf", "mlx", "sgl", or "mock".
            If None, defaults to "vllm" if CUDA is available, otherwise "hf".
        llm_opts (dict, optional): Additional options to pass to the backend constructor.
            See AsyncVirtualLM and AsyncTransformer documentation for details.

    Returns:
        (AsyncLM): An asynchronous language model.

    Raises:
        (ValueError): If an invalid backend is specified.
    """
    if backend is None:
        backend = "vllm" if torch.cuda.is_available() else "hf"

    options = {} if llm_opts is None else llm_opts

    # Each class is wrapped in a lambda so that only the selected backend's
    # class is looked up.
    backend_table = (
        ("vllm", lambda: AsyncVirtualLM),
        ("hf", lambda: AsyncTransformer),
        ("mock", lambda: MockAsyncLM),
        ("mlx", lambda: AsyncMlxLM),
        ("sgl", lambda: AsyncSGLTransformer),
    )
    for label, resolve_class in backend_table:
        if backend == label:
            return resolve_class().from_name(name, **options)

    raise ValueError(f"Invalid backend: {backend}")