
vllm

GlobalLogprobsCapture

Bases: LogitsProcessor

A global logits processor that captures full vocabulary logprobs.

This processor is injected once into the vLLM v1 engine and records the log probabilities for the most recent sampling step, as a single [batch_size, vocab_size] tensor.

Semantics:

  • apply() is invoked by the v1 sampler exactly once per decode step for the whole batch of prompts, so the captured tensor always reflects the final token-position logprobs for every prompt in that batch.
  • It does NOT retain history. Each apply() call overwrites _captured_batch. This is intentional: on the next_token_logprobs paths of AsyncVirtualLM, every generate call is issued with max_tokens=1 and preceded by clear(), so exactly one decode step runs and the overwrite never hides information. On the sampling paths (sample(), batch_sample()), where max_tokens > 1, apply() fires once per step; the final-step capture is correct but earlier steps are discarded, which is harmless because callers of those methods never read _captured_batch.
  • Concurrent reads and writes are serialized by _lock, so a consumer thread calling get_logprobs() never observes a half-written tensor.
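The overwrite-without-history contract can be illustrated without vLLM or a GPU. The following is a minimal sketch, not genlm code: `ToyLogprobsCapture` is a hypothetical pure-Python stand-in that stores lists of floats instead of tensors and uses the standard-library `threading.Lock`, mirroring the apply() / get_all_logprobs() / clear() behaviour described above.

```python
import math
import threading


def log_softmax_row(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


class ToyLogprobsCapture:
    """Hypothetical stand-in mirroring GlobalLogprobsCapture's contract."""

    def __init__(self):
        self._captured_batch = None  # one row per prompt in the last step
        self._lock = threading.Lock()

    def apply(self, logits):
        # Compute outside the lock; only the reference swap is guarded.
        captured = [log_softmax_row(row) for row in logits]
        with self._lock:
            self._captured_batch = captured  # overwrite: no history kept
        return logits  # logits pass through unchanged

    def get_all_logprobs(self):
        with self._lock:
            if self._captured_batch is None:
                return None
            return [list(row) for row in self._captured_batch]  # copy out

    def clear(self):
        with self._lock:
            self._captured_batch = None


capture = ToyLogprobsCapture()
capture.clear()  # the logprobs paths clear before each generate()
capture.apply([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])  # step 1, batch of 2
capture.apply([[5.0, 0.0, 0.0]])  # step 2 replaces step 1 entirely
print(len(capture.get_all_logprobs()))  # 1
```

With max_tokens=1 only one apply() runs per generate(), so nothing is lost; a multi-step generation leaves only the last step's rows.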
Source code in genlm/backend/llm/vllm.py
class GlobalLogprobsCapture(LogitsProcessor):  # pragma: no cover
    """A global logits processor that captures full vocabulary logprobs.

    This processor is injected once into the vLLM v1 engine and records
    the log probabilities for *the most recent* sampling step, as a
    single ``[batch_size, vocab_size]`` tensor.

    Semantics:

    * :meth:`apply` is invoked by the v1 sampler exactly once per decode
      step across a batch of prompts, so the captured tensor always
      reflects the final token-position logprobs for every prompt in
      that batch.
    * It does NOT retain history. Each :meth:`apply` call overwrites
      ``_captured_batch``. This is intentional: for the
      ``next_token_logprobs`` paths in :class:`AsyncVirtualLM`, every
      ``generate`` is issued with ``max_tokens=1`` and preceded by
      :meth:`clear`, so exactly one decode step runs and the overwrite
      never hides information. On the sampling paths (:meth:`sample`,
      :meth:`batch_sample`), where ``max_tokens > 1``, :meth:`apply`
      fires once per step; the final-step capture is correct but
      earlier steps are discarded, which is harmless because callers
      of those methods never read ``_captured_batch``.
    * Concurrent reads/writes are serialized by ``_lock``, so a
      consumer thread calling :meth:`get_logprobs` never observes a
      half-written tensor.
    """

    def __init__(self):
        self._captured_batch = None  # [batch_size, vocab_size] tensor
        self._lock = threading.Lock()

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Capture logprobs and pass through logits unchanged.

        Overwrites any previously captured batch; see class docstring.
        """
        # Do the clone outside the critical section so readers aren't blocked
        # on the full [batch, vocab] copy.
        captured = torch.log_softmax(logits, dim=-1, dtype=logits.dtype).clone()
        with self._lock:
            self._captured_batch = captured
        return logits

    def is_argmax_invariant(self) -> bool:
        """Return True since we don't modify logits."""
        return True

    def update_state(self, batch_update) -> None:
        """No state updates needed."""
        pass

    def get_logprobs(self, batch_index=0):
        """Get captured logprobs for a batch index."""
        with self._lock:
            if self._captured_batch is None:
                return None
            if batch_index >= self._captured_batch.shape[0]:
                return None
            return self._captured_batch[batch_index].clone()

    def get_all_logprobs(self):
        """Get all captured logprobs as a batch tensor."""
        with self._lock:
            if self._captured_batch is None:
                return None
            return self._captured_batch.clone()

    def clear(self):
        """Clear captured logprobs."""
        with self._lock:
            self._captured_batch = None

apply(logits)

Capture logprobs and pass through logits unchanged.

Overwrites any previously captured batch; see class docstring.

Source code in genlm/backend/llm/vllm.py
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    """Capture logprobs and pass through logits unchanged.

    Overwrites any previously captured batch; see class docstring.
    """
    # Do the clone outside the critical section so readers aren't blocked
    # on the full [batch, vocab] copy.
    captured = torch.log_softmax(logits, dim=-1, dtype=logits.dtype).clone()
    with self._lock:
        self._captured_batch = captured
    return logits
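apply() turns each row of logits into normalized log probabilities with torch.log_softmax. The dependency-free sketch below reimplements that per-row computation, log_softmax(x)_i = x_i - logsumexp(x) with the usual max-shift for numerical stability, to show the properties callers rely on: the exponentiated row sums to 1, the argmax is unchanged, and adding a constant to every logit has no effect. `log_softmax` and `argmax` here are illustrative helpers, not the PyTorch functions.

```python
import math


def log_softmax(row):
    """log_softmax(x)_i = x_i - logsumexp(x), computed with the max-shift trick."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


logits = [2.0, -1.0, 0.5, 3.5]
logprobs = log_softmax(logits)

# Normalized: probabilities sum to 1.
print(round(sum(math.exp(lp) for lp in logprobs), 12))  # 1.0
# Monotonic: the most likely token is unchanged.
print(argmax(logits) == argmax(logprobs))  # True
# Shift-invariant: adding a constant to every logit changes nothing.
shifted = log_softmax([x + 100.0 for x in logits])
print(max(abs(a - b) for a, b in zip(logprobs, shifted)) < 1e-9)  # True
```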

is_argmax_invariant()

Return True since we don't modify logits.

Source code in genlm/backend/llm/vllm.py
def is_argmax_invariant(self) -> bool:
    """Return True since we don't modify logits."""
    return True
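Returning True is safe here because apply() hands the logits back unchanged. As we understand vLLM's v1 interface, the flag asserts that the processor can never change which token greedy decoding would pick. The sketch below checks that property empirically on a few rows for three hypothetical processors (`passthrough`, `temperature`, `ban_token`); none of them are vLLM APIs. One hedged consequence worth verifying against your vLLM version: the sampler may skip argmax-invariant processors for purely greedy batches, so this capture presumably relies on non-greedy sampling params (the SamplingParams default temperature is 1.0).

```python
def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def passthrough(row):
    return row  # what GlobalLogprobsCapture.apply effectively does


def temperature(row, t=0.7):
    return [x / t for x in row]  # positive scaling keeps the order


def ban_token(row, token=0):
    return [float("-inf") if i == token else x for i, x in enumerate(row)]


def is_argmax_invariant_on(proc, rows):
    """Empirically check argmax(proc(row)) == argmax(row) on sample rows."""
    return all(argmax(proc(r)) == argmax(r) for r in rows)


rows = [[3.0, 1.0, 2.0], [0.1, 0.5, -0.2], [-1.0, -3.0, 4.0]]
print(is_argmax_invariant_on(passthrough, rows))  # True
print(is_argmax_invariant_on(temperature, rows))  # True
print(is_argmax_invariant_on(ban_token, rows))  # False: row 0's argmax is banned
```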

update_state(batch_update)

No state updates needed.

Source code in genlm/backend/llm/vllm.py
def update_state(self, batch_update) -> None:
    """No state updates needed."""
    pass

get_logprobs(batch_index=0)

Get captured logprobs for a batch index.

Source code in genlm/backend/llm/vllm.py
def get_logprobs(self, batch_index=0):
    """Get captured logprobs for a batch index."""
    with self._lock:
        if self._captured_batch is None:
            return None
        if batch_index >= self._captured_batch.shape[0]:
            return None
        return self._captured_batch[batch_index].clone()
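get_logprobs() only checks the upper bound of batch_index. The hypothetical `get_row` helper below mirrors its control flow on nested lists (a list copy standing in for `.clone()`) to make the edge cases concrete: nothing captured returns None, an index past the end returns None, a negative index selects from the end as in ordinary Python and PyTorch indexing, and a sufficiently negative index raises IndexError instead of returning None.

```python
def get_row(captured_batch, batch_index=0):
    """Hypothetical mirror of get_logprobs' control flow on nested lists."""
    if captured_batch is None:
        return None
    if batch_index >= len(captured_batch):
        return None
    return list(captured_batch[batch_index])  # copy, like .clone()


batch = [[-0.1, -2.3], [-1.2, -0.4]]
print(get_row(None))  # None: nothing captured yet (or after clear())
print(get_row(batch, 5))  # None: index past the end of the batch
print(get_row(batch, -1))  # [-1.2, -0.4]: negative indices are not rejected
try:
    get_row(batch, -5)
except IndexError:
    print("IndexError")  # too-negative indices raise rather than return None
```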

get_all_logprobs()

Get all captured logprobs as a batch tensor.

Source code in genlm/backend/llm/vllm.py
def get_all_logprobs(self):
    """Get all captured logprobs as a batch tensor."""
    with self._lock:
        if self._captured_batch is None:
            return None
        return self._captured_batch.clone()

clear()

Clear captured logprobs.

Source code in genlm/backend/llm/vllm.py
def clear(self):
    """Clear captured logprobs."""
    with self._lock:
        self._captured_batch = None

AsyncVirtualLM

Bases: AsyncLM

Async language model using vLLM v1 with a global logits processor.

This implementation uses vLLM v1's in-process mode with a global logits processor to efficiently capture full vocabulary log probabilities.

Source code in genlm/backend/llm/vllm.py
class AsyncVirtualLM(AsyncLM):  # pragma: no cover
    """Async language model using vLLM v1 with a global logits processor.

    This implementation uses vLLM v1's in-process mode with a global
    logits processor to efficiently capture full vocabulary log probabilities.
    """

    default_params = {
        "max_tokens": 1,
        "n": 1,
        "detokenize": False,
        "stop": None,
        "ignore_eos": True,
    }

    def __init__(
        self,
        llm_engine,
        logprobs_capture,
        cache_size=0,
        cache_opts=None,
        batch_size=20,
        timeout=0.02,
    ):
        """Initialize an `AsyncVirtualLM` instance.

        Args:
            llm_engine (LLM): The vLLM engine instance.
            logprobs_capture (GlobalLogprobsCapture): The global logprobs capture processor.
            cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
            cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to None (no extra options).
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching. Defaults to 20.
            timeout (float, optional): Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when ``batch_size`` is reached. Defaults to 0.02.

        Note:
            The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
            ``batch_next_token_logprobs_sync`` bypasses this cache and always re-evaluates; the other three logprobs methods consult it.
        """
        self.llm_engine = llm_engine
        self.logprobs_capture = logprobs_capture
        self.tokenizer = llm_engine.get_tokenizer()
        self.cache = (
            OutputCache(maxsize=cache_size, **(cache_opts or {}))
            if cache_size > 0
            else None
        )
        self.lora_request = None
        self.lora_name_to_ids = {}

        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        self.sample_queries = []
        self.sample_timer = None

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, engine_opts=None, **kwargs):
        """Create an `AsyncVirtualLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load.
            engine_opts (dict, optional): Additional options to pass to the `LLM` engine. Defaults to None.
            **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

        Returns:
            (AsyncVirtualLM): An `AsyncVirtualLM` instance.
        """
        if not HAS_VLLM:
            raise ImportError(  # pragma: no cover
                "vLLM not available. Install vLLM or use AsyncTransformer instead."
            )

        engine_opts = {
            "enable_prefix_caching": True,
            "disable_log_stats": True,
            "gpu_memory_utilization": 0.9,
            **(engine_opts or {}),
        }

        llm = LLM(model=model_name, tokenizer=model_name, **engine_opts)

        logprobs_capture = GlobalLogprobsCapture()
        model_runner = cls._get_model_runner(llm)
        model_runner.input_batch.logitsprocs.argmax_invariant.append(
            logprobs_capture
        )

        return cls(llm, logprobs_capture, **kwargs)

    @staticmethod
    def _get_model_runner(llm):
        """Walk the vLLM v1 internals to reach the driver worker's model runner.

        This path is brittle against vLLM refactors, so it lives in one
        place and is reused by ``from_name`` (to inject the logits
        processor) and ``underlying_model``.
        """
        engine_core = llm.llm_engine.engine_core.engine_core
        return engine_core.model_executor.driver_worker.worker.model_runner

    @property
    def underlying_model(self):
        """Access the underlying model for advanced use cases."""
        return self._get_model_runner(self.llm_engine).model

    def clear_lora(self):
        """
        Disable any active LoRA adapter for the vLLM engine.
        """
        self.lora_request = None

    def add_new_lora(self, lora_path, lora_name="lora_1"):
        """Register a LoRA adapter name and assign it a deterministic integer ID.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub. Currently unused by this method; pass the same path to `set_lora()`.
            lora_name (str): Name to assign to the adapter.

        Notes:
            No weights are loaded and the adapter is not activated by this method. Call `set_lora()` to enable the adapter.
        """
        self.lora_name_to_ids[lora_name] = self.hash_to_int(lora_name)

    def hash_to_int(self, value):
        """Derive a deterministic integer ID for a LoRA adapter from its name.

        Args:
            value (str): The name of the LoRA adapter to hash.

        Returns:
            (int): An ID in the range [1, 2^31 - 2]. Distinct names almost always get distinct IDs, but collisions are possible.
        """
        hash_bytes = hashlib.shake_128(value.encode("utf-8")).digest(4)
        return (int.from_bytes(hash_bytes, "big") % (2**31 - 2)) + 1

    def set_lora(self, lora_path, lora_name="lora_1"):
        """Configure a LoRA adapter request for the vLLM engine.

        Args:
            lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
            lora_name (str): Name of an adapter previously registered with `add_new_lora()`.

        Raises:
            ValueError: If `lora_name` has not been registered with `add_new_lora()`.
        """
        if lora_name not in self.lora_name_to_ids:
            raise ValueError(
                f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
            )
        self.lora_request = LoRARequest(
            lora_name, self.lora_name_to_ids[lora_name], lora_path
        )

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously with auto-batching.

        Concurrent calls to this method are automatically batched into a single
        ``LLM.generate()`` call for efficiency. Use with ``await``.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            result (torch.Tensor): Normalized log probability tensor.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        future = asyncio.get_running_loop().create_future()
        self._add_query(token_ids, future)
        result = await future

        if self.cache is not None:
            self.cache[key] = result

        return result

    def _add_query(self, token_ids, future):
        """Add a query to be evaluated in the next batch.

        The timeout is measured from the *first* queued query, not the most
        recent one: we only arm the timer when the queue transitions from
        empty to non-empty. This prevents starvation when queries trickle in
        faster than ``self.timeout`` but never fill a batch.

        Args:
            token_ids (list[int]): Token IDs representing the query prompt.
            future (asyncio.Future): Future to store the result in.
        """
        self.queries.append((token_ids, future))

        if len(self.queries) >= self.batch_size:
            if self.timer:
                self.timer.cancel()
                self.timer = None
            self._batch_evaluate()
        elif self.timer is None:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, self._batch_evaluate
            )

    def _batch_evaluate(self):
        """Process all queued queries in a single batched ``generate()`` call."""
        queries, self.queries = self.queries, []
        if not queries:
            return

        if self.timer:
            self.timer.cancel()
            self.timer = None

        if self.logprobs_capture is None:
            exc = RuntimeError("Cannot use model after cleanup() has been called")
            for _, future in queries:
                if not future.done():  # skip futures cancelled by the caller
                    future.set_exception(exc)
            return

        # Deduplicate: group futures by identical prompts
        query_groups = defaultdict(list)
        for token_ids, future in queries:
            query_groups[tuple(token_ids)].append(future)

        unique_token_ids = list(query_groups.keys())

        self.logprobs_capture.clear()

        prompts = [
            TokensPrompt(prompt_token_ids=list(token_ids))
            for token_ids in unique_token_ids
        ]

        try:
            self.llm_engine.generate(
                prompts=prompts,
                sampling_params=SamplingParams(**self.default_params),
                lora_request=self.lora_request,
                use_tqdm=False,
            )

            all_logprobs = self.logprobs_capture.get_all_logprobs()
            assert all_logprobs is not None, "Logprobs should be captured"
            assert all_logprobs.shape[0] == len(unique_token_ids), (
                f"Expected {len(unique_token_ids)} logprobs, got {all_logprobs.shape[0]}"
            )

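            for i, key in enumerate(unique_token_ids):
                logprobs = all_logprobs[i]
                # Skip futures the caller already cancelled: set_result()
                # on a done future raises InvalidStateError, which would
                # otherwise fail every remaining query in the batch.
                pending = [f for f in query_groups[key] if not f.done()]
                if len(pending) == 1:
                    pending[0].set_result(logprobs)
                else:
                    for future in pending:
                        future.set_result(logprobs.clone())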
        except Exception as exc:
            for futures in query_groups.values():
                for future in futures:
                    if not future.done():
                        future.set_exception(exc)

    def reset_async_queries(self):
        """Clear any pending queries from the queue.

        Use this method when an exception prevented an inference algorithm
        from executing to completion.
        """
        self.queries = []
        if self.timer:
            self.timer.cancel()
            self.timer = None

        self.sample_queries = []
        if self.sample_timer:
            self.sample_timer.cancel()
            self.sample_timer = None

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Does not support auto-batching. For batched sync calls, use
        ``batch_next_token_logprobs_sync`` instead.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        if self.logprobs_capture is None:
            raise RuntimeError("Cannot use model after cleanup() has been called")

        self.logprobs_capture.clear()

        self.llm_engine.generate(
            prompts=TokensPrompt(prompt_token_ids=list(token_ids)),
            sampling_params=SamplingParams(**self.default_params),
            lora_request=self.lora_request,
            use_tqdm=False,
        )

        result = self.logprobs_capture.get_logprobs(batch_index=0)
        assert result is not None, "Logprobs should be captured by global processor"

        if self.cache is not None:
            self.cache[key] = result

        return result

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """
        Request log probabilities of next tokens in a batch synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

        Returns:
            (torch.Tensor): A ``[len(token_ids_list), vocab_size]`` tensor of normalized log probabilities, one row per prompt in the input list.

        Note:
            This method does not consult the output cache (unlike the async batch path,
            which delegates to the cached ``next_token_logprobs``). Every prompt is
            re-evaluated.
        """
        if self.logprobs_capture is None:
            raise RuntimeError("Cannot use model after cleanup() has been called")
        # Clear any stale captured logprobs
        self.logprobs_capture.clear()

        # Create prompts for batch
        prompts = [
            TokensPrompt(prompt_token_ids=list(token_ids))
            for token_ids in token_ids_list
        ]

        # Generate one token for each prompt
        self.llm_engine.generate(
            prompts=prompts,
            sampling_params=SamplingParams(**self.default_params),
            lora_request=self.lora_request,
            use_tqdm=False,
        )

        # Get all captured logprobs at once (optimized - single clone)
        all_logprobs = self.logprobs_capture.get_all_logprobs()
        assert all_logprobs is not None, "Logprobs should be captured"
        assert all_logprobs.shape[0] == len(token_ids_list), (
            f"Expected {len(token_ids_list)} logprobs, got {all_logprobs.shape[0]}"
        )

        return all_logprobs

    def clear_cache(self):
        """Clear output cache."""
        if self.cache:
            self.cache.clear()

    def cleanup(self):
        """Explicitly clean up GPU resources. Call this when done with the model."""
        self._cleanup_engine()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()
        return False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()
        return False

    def __del__(self):
        """Clean up resources on deletion."""
        self._cleanup_engine()

    def _cleanup_engine(self):
        """Clean up the vLLM engine and associated resources.

        This is invoked from both :meth:`cleanup` (explicit, during normal
        program flow) and :meth:`__del__` (implicit, possibly at
        interpreter shutdown). The narrow exception classes below cover
        the races and idempotency issues we know about:

        * ``ImportError`` / ``AttributeError`` arise when ``__del__`` runs
          after ``sys.meta_path`` is already torn down during interpreter
          shutdown.
        * ``AssertionError`` is raised by vLLM's
          ``destroy_distributed_environment`` if it's called twice.
        * ``RuntimeError`` can surface from CUDA when the driver is
          already being torn down.

        Anything else is re-raised so real bugs are not swallowed.
        """
        try:
            # ``import gc`` can itself raise ImportError when ``__del__`` is
            # invoked after ``sys.meta_path`` has been torn down at
            # interpreter shutdown, so it lives inside the try block.
            import gc

            # Clear our references
            if hasattr(self, "logprobs_capture"):
                if self.logprobs_capture is not None:
                    self.logprobs_capture.clear()
                self.logprobs_capture = None

            # Delete the engine to free GPU memory
            if hasattr(self, "llm_engine") and self.llm_engine is not None:
                del self.llm_engine
                self.llm_engine = None

            # Force garbage collection
            gc.collect()

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Clean up distributed state
            destroy_model_parallel()
            destroy_distributed_environment()
        except (
            ImportError,
            AttributeError,
            AssertionError,
            RuntimeError,
        ) as e:
            # Best-effort log; during interpreter shutdown logging itself
            # may already be torn down, in which case silently drop.
            with contextlib.suppress(Exception):
                logging.getLogger(__name__).debug(
                    "AsyncVirtualLM cleanup raised %s: %s",
                    type(e).__name__,
                    e,
                )

    def _add_sample_query(self, prompt_token_ids, sampling_params, future):
        """Enqueue a ``sample()`` request; mirrors ``_add_query`` for the logprobs path."""
        self.sample_queries.append((prompt_token_ids, sampling_params, future))
        if len(self.sample_queries) >= self.batch_size:
            if self.sample_timer:
                self.sample_timer.cancel()
                self.sample_timer = None
            self._batch_sample_evaluate()
        elif self.sample_timer is None:
            self.sample_timer = asyncio.get_running_loop().call_later(
                self.timeout, self._batch_sample_evaluate
            )

    def _batch_sample_evaluate(self):
        """Dispatch queued ``sample()`` requests in one batched ``generate()`` call."""
        queries, self.sample_queries = self.sample_queries, []
        if not queries:
            return
        if self.sample_timer:
            self.sample_timer.cancel()
            self.sample_timer = None
        if self.logprobs_capture is None:
            exc = RuntimeError("Cannot use model after cleanup() has been called")
            for _, _, future in queries:
                if not future.done():  # skip futures cancelled by the caller
                    future.set_exception(exc)
            return
        try:
            outputs = self.llm_engine.generate(
                prompts=[TokensPrompt(prompt_token_ids=t) for t, _, _ in queries],
                sampling_params=[sp for _, sp, _ in queries],
                lora_request=self.lora_request,
                use_tqdm=False,
            )
            assert len(outputs) == len(queries)
            for output, (_, _, future) in zip(outputs, queries):
                if not future.done():  # skip futures cancelled by the caller
                    future.set_result(list(output.outputs[0].token_ids))
        except Exception as exc:
            for _, _, future in queries:
                if not future.done():
                    future.set_exception(exc)

    async def sample(
        self,
        prompt_token_ids,
        max_tokens,
        eos_token_ids,
        temperature=1.0,
        seed=None,
    ):
        """Sample from the language model.

        Concurrent calls are auto-batched into a single ``LLM.generate()``
        so vLLM continuous-batches the decode steps. Use with ``await``.

        Args:
            prompt_token_ids (list[int]): The token IDs of the prompt.
            max_tokens (int): The maximum number of tokens to generate.
            eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens. A trailing EOS token is stripped from the result.
            temperature (float, optional): The temperature to use to rescale the logits. Defaults to 1.0.
            seed (int, optional): The seed for the random number generator. Defaults to None.

        Returns:
            (list[int]): The sampled token IDs.
        """
        future = asyncio.get_running_loop().create_future()
        self._add_sample_query(
            list(prompt_token_ids),
            SamplingParams(
                n=1,
                max_tokens=max_tokens,
                temperature=temperature,
                seed=seed,
                stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
            ),
            future,
        )
        token_ids = await future
        if token_ids and token_ids[-1] in eos_token_ids:
            token_ids = token_ids[:-1]
        return token_ids
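The auto-batching in `_add_query` / `_batch_evaluate` (mirrored for `sample()` by `_add_sample_query` / `_batch_sample_evaluate`) follows a common asyncio micro-batching pattern: arm a `loop.call_later` timer only when the queue goes from empty to non-empty, flush immediately once `batch_size` is reached, and evaluate each distinct prompt once. The sketch below is a hypothetical, engine-free reimplementation; `MicroBatcher`, `submit`, and `fake_evaluate` are illustrative names, and `fake_evaluate` stands in for the single `LLM.generate()` call. It also skips futures that are already done (for example, cancelled by the caller), because `asyncio.Future.set_result` raises `InvalidStateError` on them.

```python
import asyncio
from collections import defaultdict


class MicroBatcher:
    """Hypothetical sketch of the queue / timer / dedupe pattern."""

    def __init__(self, evaluate, batch_size=4, timeout=0.02):
        self.evaluate = evaluate  # fn(list[tuple]) -> list of results
        self.batch_size = batch_size
        self.timeout = timeout
        self.queries = []
        self.timer = None

    async def submit(self, token_ids):
        future = asyncio.get_running_loop().create_future()
        self.queries.append((tuple(token_ids), future))
        if len(self.queries) >= self.batch_size:
            self._flush()  # full batch: fire immediately
        elif self.timer is None:  # arm only on empty -> non-empty
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, self._flush
            )
        return await future

    def _flush(self):
        queries, self.queries = self.queries, []
        if self.timer is not None:
            self.timer.cancel()
            self.timer = None
        if not queries:
            return
        groups = defaultdict(list)  # deduplicate identical prompts
        for key, future in queries:
            groups[key].append(future)
        keys = list(groups)
        try:
            results = self.evaluate(keys)
            for key, result in zip(keys, results):
                for future in groups[key]:
                    if not future.done():
                        future.set_result(result)
        except Exception as exc:
            for futures in groups.values():
                for future in futures:
                    if not future.done():
                        future.set_exception(exc)


calls = []


def fake_evaluate(keys):
    calls.append(list(keys))
    return [sum(k) for k in keys]  # stand-in for a logprobs row


async def main():
    batcher = MicroBatcher(fake_evaluate, batch_size=4, timeout=0.01)
    return await asyncio.gather(
        batcher.submit([1, 2]), batcher.submit([5]), batcher.submit([1, 2])
    )


print(asyncio.run(main()))  # [3, 5, 3]
print(calls)  # [[(1, 2), (5,)]]: one batch, the duplicate evaluated once
```

Measuring the timeout from the first queued query bounds the added latency at roughly `timeout` even when queries keep trickling in below `batch_size`.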

__init__(llm_engine, logprobs_capture, cache_size=0, cache_opts=None, batch_size=20, timeout=0.02)

Initialize an AsyncVirtualLM instance.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| llm_engine | LLM | The vLLM engine instance. | required |
| logprobs_capture | GlobalLogprobsCapture | The global logprobs capture processor. | required |
| cache_size | int | Maximum size of the output cache. If 0, caching is disabled. Defaults to 0. | 0 |
| cache_opts | dict | Additional options to pass to the OutputCache constructor. Defaults to None (no extra options). | None |
| batch_size | int | Maximum queries to process in one batch during auto-batching. Defaults to 20. | 20 |
| timeout | float | Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when batch_size is reached. Defaults to 0.02. | 0.02 |

Note

The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine. batch_next_token_logprobs_sync bypasses this cache and always re-evaluates; the other three logprobs methods consult it.

Source code in genlm/backend/llm/vllm.py
def __init__(
    self,
    llm_engine,
    logprobs_capture,
    cache_size=0,
    cache_opts=None,
    batch_size=20,
    timeout=0.02,
):
    """Initialize an `AsyncVirtualLM` instance.

    Args:
        llm_engine (LLM): The vLLM engine instance.
        logprobs_capture (GlobalLogprobsCapture): The global logprobs capture processor.
        cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
        cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm.backend.cache.OutputCache] constructor. Defaults to None (no extra options).
        batch_size (int, optional): Maximum queries to process in one batch during auto-batching. Defaults to 20.
        timeout (float, optional): Seconds to wait after the first queued query before processing the current batch. The batch also fires immediately when ``batch_size`` is reached. Defaults to 0.02.

    Note:
        The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
        ``batch_next_token_logprobs_sync`` bypasses this cache and always re-evaluates; the other three logprobs methods consult it.
    """
    self.llm_engine = llm_engine
    self.logprobs_capture = logprobs_capture
    self.tokenizer = llm_engine.get_tokenizer()
    self.cache = (
        OutputCache(maxsize=cache_size, **(cache_opts or {}))
        if cache_size > 0
        else None
    )
    self.lora_request = None
    self.lora_name_to_ids = {}

    self.queries = []
    self.batch_size = batch_size
    self.timeout = timeout
    self.timer = None

    self.sample_queries = []
    self.sample_timer = None

    super().__init__(tokenizer=self.tokenizer)
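The caching described in the note is a memoize-by-key pattern: the token-ID list is converted to a tuple (lists are unhashable), looked up before any engine call, and stored after a miss. The sketch below is hypothetical: `LRUCache` stands in for `OutputCache`, whose eviction policy is not described on this page, and `expensive_logprobs` stands in for a `generate()` round trip.

```python
from collections import OrderedDict


class LRUCache:
    """Hypothetical stand-in for OutputCache: a bounded LRU mapping."""

    def __init__(self, maxsize):
        self.maxsize = maxsize
        self._data = OrderedDict()

    def __contains__(self, key):
        return key in self._data

    def __getitem__(self, key):
        self._data.move_to_end(key)  # mark as recently used
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # evict least recently used


evaluations = []


def expensive_logprobs(token_ids):
    evaluations.append(token_ids)
    return [-float(len(token_ids))]  # placeholder result


def cached_logprobs(cache, token_ids):
    key = tuple(token_ids)  # lists are unhashable; tuples work as keys
    if key in cache:
        return cache[key]
    result = expensive_logprobs(token_ids)
    cache[key] = result
    return result


cache = LRUCache(maxsize=2)
cached_logprobs(cache, [1, 2, 3])
cached_logprobs(cache, [1, 2, 3])  # hit: no second evaluation
print(len(evaluations))  # 1
```

A path that skips the lookup, as `batch_next_token_logprobs_sync` does, simply calls the expensive function every time.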

from_name(model_name, engine_opts=None, **kwargs) classmethod

Create an AsyncVirtualLM instance from a model name.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | Name of the model to load. | required |
| engine_opts | dict | Additional options to pass to the LLM engine. | None |
| **kwargs | | Additional arguments passed to the AsyncVirtualLM constructor. | {} |

Returns:

| Type | Description |
|---|---|
| AsyncVirtualLM | An AsyncVirtualLM instance. |

Source code in genlm/backend/llm/vllm.py
@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
    """Create an `AsyncVirtualLM` instance from a model name.

    Args:
        model_name (str): Name of the model to load.
        engine_opts (dict, optional): Additional options to pass to the `LLM` engine. Defaults to None.
        **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

    Returns:
        (AsyncVirtualLM): An `AsyncVirtualLM` instance.
    """
    if not HAS_VLLM:
        raise ImportError(  # pragma: no cover
            "vLLM not available. Install vLLM or use AsyncTransformer instead."
        )

    engine_opts = {
        "enable_prefix_caching": True,
        "disable_log_stats": True,
        "gpu_memory_utilization": 0.9,
        **(engine_opts or {}),
    }

    llm = LLM(model=model_name, tokenizer=model_name, **engine_opts)

    logprobs_capture = GlobalLogprobsCapture()
    model_runner = cls._get_model_runner(llm)
    model_runner.input_batch.logitsprocs.argmax_invariant.append(
        logprobs_capture
    )

    return cls(llm, logprobs_capture, **kwargs)
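from_name builds its engine options by unpacking the caller's engine_opts after a set of defaults, so any key the caller passes wins. A minimal sketch of that merge, using a hypothetical helper name merge_engine_opts (not part of the library) and only the three defaults visible in the code above:

```python
DEFAULT_ENGINE_OPTS = {
    "enable_prefix_caching": True,
    "disable_log_stats": True,
    "gpu_memory_utilization": 0.9,
}


def merge_engine_opts(engine_opts=None):
    """Hypothetical helper mirroring from_name: caller-supplied keys override defaults."""
    # In a dict display, later keys win, so the user's options are unpacked last.
    return {**DEFAULT_ENGINE_OPTS, **(engine_opts or {})}


# Example: lower the GPU memory fraction and cap the context length.
# max_model_len is an ordinary vLLM LLM() keyword, passed through unchanged.
opts = merge_engine_opts({"gpu_memory_utilization": 0.5, "max_model_len": 4096})
```

The merged dict is what from_name unpacks into LLM(model=..., tokenizer=..., **engine_opts), so a call such as AsyncVirtualLM.from_name(name, engine_opts={"gpu_memory_utilization": 0.5}) would keep prefix caching enabled while lowering the memory fraction.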

underlying_model property

Access the underlying model for advanced use cases.

clear_lora()

Disable any active LoRA adapter for the vLLM engine.

Source code in genlm/backend/llm/vllm.py
def clear_lora(self):
    """
    Disable any active LoRA adapter for the vLLM engine.
    """
    self.lora_request = None

add_new_lora(lora_path, lora_name='lora_1')

Register a LoRA adapter name by assigning it a deterministic integer id (computed by hash_to_int).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lora_path | str | Path to the adapter weights directory or identifier in HuggingFace's model hub. Not used by this method; the path given to set_lora() is the one passed to vLLM. | required |
| lora_name | str | Name to assign to the adapter. | 'lora_1' |

Notes

This does not load or activate the adapter. Call set_lora() with the same lora_name to enable it.

Source code in genlm/backend/llm/vllm.py
def add_new_lora(self, lora_path, lora_name="lora_1"):
    """Register a LoRA adapter name by assigning it a deterministic integer id.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Name to assign to the loaded adapter.

    Notes:
        This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
    """
    self.lora_name_to_ids[lora_name] = self.hash_to_int(lora_name)

hash_to_int(value)

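Generate a deterministic integer id for a LoRA adapter from its name. The id is derived from a 4-byte SHAKE-128 digest, so distinct names are very likely, though not guaranteed, to receive distinct ids.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| value | str | The name of the LoRA adapter to hash. | required |

Returns:

| Type | Description |
|---|---|
| int | An integer id for the LoRA adapter, in the range [1, 2^31 - 2]. |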

Source code in genlm/backend/llm/vllm.py
def hash_to_int(self, value):
    """Generate a deterministic integer id for a LoRA adapter from its name.

    Args:
        value (str): The name of the LoRA adapter to hash.

    Returns:
        An integer ID corresponding to the LoRA adapter, in the range [1, 2^31 - 2].
    """
    hash_bytes = hashlib.shake_128(value.encode("utf-8")).digest(4)
    return (int.from_bytes(hash_bytes, "big") % (2**31 - 2)) + 1
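The id is the first 4 bytes of a SHAKE-128 digest, read as a big-endian integer and folded into [1, 2^31 - 2], so it is never 0. The standalone sketch below reproduces that arithmetic outside the class (lora_id_for is a hypothetical name) and checks the bounds; it says nothing about how vLLM itself uses the id.

```python
import hashlib


def lora_id_for(name: str) -> int:
    """Hypothetical standalone copy of hash_to_int's arithmetic."""
    # 4 bytes of SHAKE-128 output give an integer in [0, 2**32 - 1].
    raw = int.from_bytes(hashlib.shake_128(name.encode("utf-8")).digest(4), "big")
    # Fold into [0, 2**31 - 3], then shift up by one to get [1, 2**31 - 2].
    return raw % (2**31 - 2) + 1


ids = {name: lora_id_for(name) for name in ["lora_1", "lora_2", "math-adapter"]}
```

Because the mapping depends only on the name, registering the same name again yields the same id; two different names can in principle collide, and this scheme does not detect that.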

set_lora(lora_path, lora_name='lora_1')

Configure a LoRA adapter request for the vLLM engine.

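Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lora_path | str | Path to the adapter weights directory or identifier in HuggingFace's model hub. | required |
| lora_name | str | Identifier name to associate with this LoRA adapter within vLLM. Must already be registered with add_new_lora(). | 'lora_1' |

Raises:

| Type | Description |
|---|---|
| ValueError | If lora_name has not been registered with add_new_lora(). |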
Source code in genlm/backend/llm/vllm.py
def set_lora(self, lora_path, lora_name="lora_1"):
    """Configure a LoRA adapter request for the vLLM engine.

    Args:
        lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
        lora_name (str): Identifier name to associate with this LoRA adapter within vLLM.

    Raises:
        ValueError: If `lora_name` has not been registered with `add_new_lora()`.
    """
    if lora_name not in self.lora_name_to_ids:
        raise ValueError(
            f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters."
        )
    self.lora_request = LoRARequest(
        lora_name, self.lora_name_to_ids[lora_name], lora_path
    )
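The three LoRA methods form a small state machine: add_new_lora registers a name, set_lora turns a registered name into the active request (raising ValueError otherwise), and clear_lora deactivates it. The pure-Python stand-in below mirrors that control flow so it runs without vLLM. LoRARegistry and FakeLoRARequest are hypothetical names; FakeLoRARequest only imitates the positional (name, id, path) arguments passed to vLLM's LoRARequest above, and "path/to/adapter" is a placeholder path.

```python
import hashlib
from typing import NamedTuple, Optional


class FakeLoRARequest(NamedTuple):
    """Hypothetical stand-in for vLLM's LoRARequest(name, int_id, path)."""
    lora_name: str
    lora_int_id: int
    lora_path: str


class LoRARegistry:
    """Hypothetical mirror of add_new_lora / set_lora / clear_lora."""

    def __init__(self):
        self.lora_name_to_ids = {}
        self.lora_request: Optional[FakeLoRARequest] = None

    @staticmethod
    def hash_to_int(value):
        digest = hashlib.shake_128(value.encode("utf-8")).digest(4)
        return (int.from_bytes(digest, "big") % (2**31 - 2)) + 1

    def add_new_lora(self, lora_path, lora_name="lora_1"):
        # Only the name is recorded; lora_path is unused here, as in the original.
        self.lora_name_to_ids[lora_name] = self.hash_to_int(lora_name)

    def set_lora(self, lora_path, lora_name="lora_1"):
        if lora_name not in self.lora_name_to_ids:
            raise ValueError(f"LoRA adapter '{lora_name}' has not been added yet")
        self.lora_request = FakeLoRARequest(
            lora_name, self.lora_name_to_ids[lora_name], lora_path
        )

    def clear_lora(self):
        self.lora_request = None


reg = LoRARegistry()
try:
    reg.set_lora("path/to/adapter", "math")  # not registered yet
    raised_before_add = False
except ValueError:
    raised_before_add = True

reg.add_new_lora("path/to/adapter", "math")
reg.set_lora("path/to/adapter", "math")
active = reg.lora_request  # the request that generate() calls would carry
reg.clear_lora()
```

With the real class, the same sequence would be llm.add_new_lora(path, "math"), then llm.set_lora(path, "math"), then llm.clear_lora() when done, where llm is an AsyncVirtualLM. The engine typically also needs LoRA support switched on, for example by passing enable_lora=True in engine_opts.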

next_token_logprobs(token_ids) async

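Request the log probabilities of the next token asynchronously, with auto-batching.

Concurrent calls to this method are batched into a single LLM.generate() call for efficiency. A batch is sent when batch_size queries are queued or timeout seconds after the first queued query, whichever comes first. If the output cache is enabled, previously seen prompts are answered from it without being queued. Use with await.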

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_ids | list[int] | A list of token IDs, representing a prompt to the language model. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| result | Tensor | Normalized log probability tensor. |

Source code in genlm/backend/llm/vllm.py
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token asynchronously with auto-batching.

    Concurrent calls to this method are automatically batched into a single
    ``LLM.generate()`` call for efficiency. Use with ``await``.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        result (torch.Tensor): Normalized log probability tensor.
    """
    key = tuple(token_ids)

    if self.cache is not None and key in self.cache:
        return self.cache[key]

    future = asyncio.get_running_loop().create_future()
    self._add_query(token_ids, future)
    result = await future

    if self.cache is not None:
        self.cache[key] = result

    return result
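The auto-batching described above follows a common micro-batching pattern: each call enqueues a (prompt, future) pair, and the queue is flushed in one batched call either when it reaches batch_size or timeout seconds after the first queued item. The sketch below shows that pattern in isolation. MicroBatcher, submit, and batch_fn are hypothetical names; batch_fn stands in for the single LLM.generate() call, the cache lookup is omitted, and the library's own _add_query internals are not shown on this page and may differ.

```python
import asyncio


class MicroBatcher:
    """Hypothetical sketch of size-or-timeout micro-batching with futures."""

    def __init__(self, batch_fn, batch_size=20, timeout=0.02):
        self.batch_fn = batch_fn      # processes a list of inputs in one call
        self.batch_size = batch_size
        self.timeout = timeout
        self.queries = []             # pending (input, future) pairs
        self.timer = None
        self.calls = 0                # number of batched calls made

    def submit(self, x):
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self.queries.append((x, future))
        if len(self.queries) >= self.batch_size:
            self._flush()             # full batch: fire immediately
        elif self.timer is None:
            # The first query of a new batch starts the timeout clock.
            self.timer = loop.call_later(self.timeout, self._flush)
        return future

    def _flush(self):
        if self.timer is not None:
            self.timer.cancel()
            self.timer = None
        batch, self.queries = self.queries, []
        if not batch:
            return
        self.calls += 1
        results = self.batch_fn([x for x, _ in batch])
        for (_, future), result in zip(batch, results):
            if not future.done():
                future.set_result(result)


async def demo(batch_size, n):
    batcher = MicroBatcher(lambda xs: [x * 2 for x in xs], batch_size=batch_size, timeout=0.01)
    results = await asyncio.gather(*(batcher.submit(i) for i in range(n)))
    return results, batcher.calls


results, calls = asyncio.run(demo(batch_size=20, n=5))            # one timed batch
results_small, calls_small = asyncio.run(demo(batch_size=2, n=5))  # 2 + 2 + 1
```

Concurrent awaits resolve from a single batched call, which is why issuing many next_token_logprobs calls together (for example through asyncio.gather) is cheaper than awaiting them one at a time.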

reset_async_queries()

Clear any pending queries from the queue.

Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm/backend/llm/vllm.py
def reset_async_queries(self):
    """Clear any pending queries from the queue.

    Use this method when an exception prevented an inference algorithm
    from executing to completion.
    """
    self.queries = []
    if self.timer:
        self.timer.cancel()
        self.timer = None

    self.sample_queries = []
    if self.sample_timer:
        self.sample_timer.cancel()
        self.sample_timer = None
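reset_async_queries relies on the fact that a timer scheduled with loop.call_later can be cancelled before it fires. The sketch below, with hypothetical names, shows that emptying the queue and cancelling the handle means the deferred flush never runs. As the method above is written, futures still sitting in the dropped queue are discarded rather than resolved, so nothing should keep awaiting them after a reset.

```python
import asyncio


async def demo_reset(delay=0.01):
    loop = asyncio.get_running_loop()
    flushed = []                                   # records batches that actually ran
    queries = [("prompt-a", loop.create_future())]

    def flush():
        flushed.append(list(queries))

    timer = loop.call_later(delay, flush)          # deferred batch, as after a first query

    # Mirror reset_async_queries: drop pending work and cancel the timer.
    queries = []
    timer.cancel()

    await asyncio.sleep(delay * 5)                 # wait well past the original deadline
    return flushed


flushed = asyncio.run(demo_reset())
```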

next_token_logprobs_sync(token_ids)

Request the log probabilities of the next token synchronously.

Does not support auto-batching. For batched sync calls, use batch_next_token_logprobs_sync instead.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_ids | list[int] | A list of token IDs, representing a prompt to the language model. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Normalized log probability tensor. |

Source code in genlm/backend/llm/vllm.py
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    Does not support auto-batching. For batched sync calls, use
    ``batch_next_token_logprobs_sync`` instead.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    key = tuple(token_ids)

    if self.cache is not None and key in self.cache:
        return self.cache[key]

    if self.logprobs_capture is None:
        raise RuntimeError("Cannot use model after cleanup() has been called")

    self.logprobs_capture.clear()

    self.llm_engine.generate(
        prompts=TokensPrompt(prompt_token_ids=list(token_ids)),
        sampling_params=SamplingParams(**self.default_params),
        lora_request=self.lora_request,
        use_tqdm=False,
    )

    result = self.logprobs_capture.get_logprobs(batch_index=0)
    assert result is not None, "Logprobs should be captured by global processor"

    if self.cache is not None:
        self.cache[key] = result

    return result
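next_token_logprobs_sync depends on the capture's "most recent step only" contract described for GlobalLogprobsCapture: clear it, run a single max_tokens=1 generate, then read row 0. The pure-Python stand-in below illustrates that contract with a lock, an overwrite-on-apply buffer, and copies on read. LatestStepCapture is a hypothetical name; the method names mirror those called on this page, but how the real class obtains log probabilities from the sampler is not shown here, and the log-softmax below is an assumption.

```python
import math
import threading


class LatestStepCapture:
    """Hypothetical stand-in for GlobalLogprobsCapture's read/write contract."""

    def __init__(self):
        self._lock = threading.Lock()
        self._captured_batch = None   # list of rows, one per prompt

    def apply(self, logits_batch):
        # Log-softmax each row, then overwrite (not append to) the previous step.
        rows = []
        for logits in logits_batch:
            m = max(logits)
            log_z = m + math.log(sum(math.exp(x - m) for x in logits))
            rows.append([x - log_z for x in logits])
        with self._lock:
            self._captured_batch = rows

    def clear(self):
        with self._lock:
            self._captured_batch = None

    def get_logprobs(self, batch_index=0):
        with self._lock:
            if self._captured_batch is None:
                return None
            return list(self._captured_batch[batch_index])

    def get_all_logprobs(self):
        with self._lock:
            if self._captured_batch is None:
                return None
            return [list(row) for row in self._captured_batch]


# Sync flow as in next_token_logprobs_sync: clear, run one decode step, read row 0.
capture = LatestStepCapture()
capture.clear()
capture.apply([[2.0, 1.0, 0.0]])                  # stands in for the single max_tokens=1 step
row = capture.get_logprobs(batch_index=0)

capture.apply([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])  # a later step overwrites; no history kept
latest = capture.get_all_logprobs()
```

Because reads return copies, a caller holding row is unaffected when a later step overwrites the buffer.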

batch_next_token_logprobs_sync(token_ids_list)

Request the log probabilities of the next token for a batch of prompts, synchronously.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_ids_list | list[list[int]] | A list of token ID lists, each representing a prompt to the language model. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | A 2-D tensor with one row of normalized log probabilities per prompt in the input list, of shape [len(token_ids_list), vocab_size]. |

Note

This method does not consult the output cache (unlike the async batch path, which delegates to the cached next_token_logprobs). Every prompt is re-evaluated.

Source code in genlm/backend/llm/vllm.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """
    Request log probabilities of next tokens in a batch synchronously.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

    Returns:
        (torch.Tensor): A 2-D tensor with one row of normalized log probabilities per prompt, of shape ``[len(token_ids_list), vocab_size]``.

    Note:
        This method does not consult the output cache (unlike the async batch path,
        which delegates to the cached ``next_token_logprobs``). Every prompt is
        re-evaluated.
    """
    if self.logprobs_capture is None:
        raise RuntimeError("Cannot use model after cleanup() has been called")
    # Clear any stale captured logprobs
    self.logprobs_capture.clear()

    # Create prompts for batch
    prompts = [
        TokensPrompt(prompt_token_ids=list(token_ids))
        for token_ids in token_ids_list
    ]

    # Generate one token for each prompt
    self.llm_engine.generate(
        prompts=prompts,
        sampling_params=SamplingParams(**self.default_params),
        lora_request=self.lora_request,
        use_tqdm=False,
    )

    # Get all captured logprobs at once (optimized - single clone)
    all_logprobs = self.logprobs_capture.get_all_logprobs()
    assert all_logprobs is not None, "Logprobs should be captured"
    assert all_logprobs.shape[0] == len(token_ids_list), (
        f"Expected {len(token_ids_list)} logprobs, got {all_logprobs.shape[0]}"
    )

    return all_logprobs
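Because batch_next_token_logprobs_sync always re-evaluates every prompt, a caller who repeats prompts may want to put a cache in front of it. A caller-side sketch under stated assumptions: cached_batch_logprobs and fake_batch are hypothetical, cache is any dict-like mapping keyed by tuple(token_ids) (the key shape next_token_logprobs uses), and batch_fn stands in for model.batch_next_token_logprobs_sync, whose result is iterated row by row.

```python
def cached_batch_logprobs(batch_fn, cache, token_ids_list):
    """Hypothetical wrapper: evaluate only prompts missing from `cache`."""
    keys = [tuple(ids) for ids in token_ids_list]
    # Unique cache misses in first-seen order (dicts preserve insertion order).
    misses = [k for k in dict.fromkeys(keys) if k not in cache]
    if misses:
        rows = batch_fn([list(k) for k in misses])  # one batched call for all misses
        for k, row in zip(misses, rows):
            cache[k] = row
    return [cache[k] for k in keys]


# Usage with a fake batch function that records what it was asked to evaluate.
calls = []


def fake_batch(prompts):
    calls.append(prompts)
    return [[float(len(p))] for p in prompts]  # dummy "logprob" rows


cache = {}
first = cached_batch_logprobs(fake_batch, cache, [[1, 2], [3], [1, 2]])
second = cached_batch_logprobs(fake_batch, cache, [[3], [4, 5, 6]])
```

The sketch assumes an unbounded dict. With a size-bounded cache such as OutputCache, an entry could in principle be evicted between being stored and being read back, so a production wrapper would keep the fresh rows in a local dict as well.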

clear_cache()

Clear output cache.

Source code in genlm/backend/llm/vllm.py
def clear_cache(self):
    """Clear output cache."""
    if self.cache:
        self.cache.clear()

cleanup()

Explicitly clean up GPU resources. Call this when done with the model.

Source code in genlm/backend/llm/vllm.py
def cleanup(self):
    """Explicitly clean up GPU resources. Call this when done with the model."""
    self._cleanup_engine()
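Since cleanup() is the explicit way to release GPU resources, one way to make sure it runs even when inference raises is a small context manager. A sketch, assuming only that the object exposes a cleanup() method; managed_model and DummyModel are hypothetical, and with the real class you would pass the result of AsyncVirtualLM.from_name(...).

```python
from contextlib import contextmanager


@contextmanager
def managed_model(model):
    """Yield `model`, then call model.cleanup() on exit, even after an exception."""
    try:
        yield model
    finally:
        model.cleanup()


class DummyModel:
    """Hypothetical stand-in that records whether cleanup() was called."""

    def __init__(self):
        self.cleaned = False

    def cleanup(self):
        self.cleaned = True


ok_model = DummyModel()
with managed_model(ok_model):
    pass

failing_model = DummyModel()
try:
    with managed_model(failing_model):
        raise RuntimeError("inference failed")
except RuntimeError:
    pass
```

Relying on __del__ alone is less predictable, since Python does not guarantee when, or at interpreter shutdown whether, finalizers run.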

__del__()

Clean up resources on deletion.

Source code in genlm/backend/llm/vllm.py
def __del__(self):
    """Clean up resources on deletion."""
    self._cleanup_engine()

sample(prompt_token_ids, max_tokens, eos_token_ids, temperature=1.0, seed=None) async

Sample from the language model.

Concurrent calls are auto-batched into a single LLM.generate() call, so vLLM can continuously batch their decode steps. Use with await.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prompt_token_ids | list[int] | The token IDs of the prompt. | required |
| max_tokens | int | The maximum number of tokens to generate. | required |
| eos_token_ids | list[int] | The token IDs of the end-of-sequence tokens. | required |
| temperature | float | The temperature used to rescale the logits. | 1.0 |
| seed | int | The seed for the random number generator. | None |

Returns:

| Type | Description |
|---|---|
| list[int] | The sampled token IDs, with a trailing end-of-sequence token removed if present. |

Source code in genlm/backend/llm/vllm.py
async def sample(
    self,
    prompt_token_ids,
    max_tokens,
    eos_token_ids,
    temperature=1.0,
    seed=None,
):
    """Sample from the language model.

    Concurrent calls are auto-batched into a single ``LLM.generate()``
    call, so vLLM can continuously batch their decode steps. Use with ``await``.

    Args:
        prompt_token_ids (list[int]): The token IDs of the prompt.
        max_tokens (int): The maximum number of tokens to generate.
        eos_token_ids (list[int]): The token IDs of the end-of-sequence tokens.
        temperature (float, optional): The temperature used to rescale the logits. Defaults to 1.0.
        seed (int, optional): The seed for the random number generator. Defaults to None.

    Returns:
        (list[int]): The sampled token IDs, with a trailing end-of-sequence token removed if present.
    """
    future = asyncio.get_running_loop().create_future()
    self._add_sample_query(
        list(prompt_token_ids),
        SamplingParams(
            n=1,
            max_tokens=max_tokens,
            temperature=temperature,
            seed=seed,
            stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
        ),
        future,
    )
    token_ids = await future
    if token_ids and token_ids[-1] in eos_token_ids:
        token_ids = token_ids[:-1]
    return token_ids
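sample passes the end-of-sequence tokens to vLLM as stop strings, decoding each one from byte_vocab, and then drops a trailing EOS id from the returned tokens. The helpers below isolate those two steps with a toy byte vocabulary; build_stop_strings and strip_trailing_eos are hypothetical names, and the vocabulary is made up for illustration.

```python
def build_stop_strings(byte_vocab, eos_token_ids):
    """Decode each EOS token's bytes to text, as sample() does for `stop=`."""
    # bytes.decode() defaults to UTF-8 and raises UnicodeDecodeError on invalid bytes.
    return [byte_vocab[i].decode() for i in eos_token_ids]


def strip_trailing_eos(token_ids, eos_token_ids):
    """Drop the last token if it is an EOS id; otherwise return tokens unchanged."""
    if token_ids and token_ids[-1] in eos_token_ids:
        return token_ids[:-1]
    return token_ids


byte_vocab = [b"Hello", b",", b" world", b"[EOS]", b"[END]"]  # toy vocabulary
eos_ids = [3, 4]
stops = build_stop_strings(byte_vocab, eos_ids)
finished = strip_trailing_eos([0, 1, 2, 3], eos_ids)  # generation stopped on EOS
truncated = strip_trailing_eos([0, 1, 2], eos_ids)    # generation hit max_tokens instead
```

Only a single trailing EOS is removed, and only when it is the final token; a response cut off by max_tokens comes back unchanged. To draw several samples concurrently, await several sample() calls together (for example with asyncio.gather) so they land in the same batch.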