Skip to content

bytellm

ByteLLM

Bases: Potential

A potential representing a language model operating at the byte level using beam search.

ByteLLM wraps a language model and uses beam search to compute log probabilities over byte sequences. This enables constrained generation at the byte level while maintaining coherent token-level probabilities through adaptive token healing.

Parameters:

Name Type Description Default
llm Any

The language model to use (from genlm.backend).

required
beam_params BeamParams

Configuration for beam search, including beam width K, eos_tokens, and healing parameters (heal, heal_max_backoff, heal_max_splits).

required
cache_size int

Maximum number of beam states to cache. Defaults to 1024.

1024
Example
from genlm.bytes import BeamParams
from genlm.control import ByteLLM

# Beam search configuration: beam width K=5, an explicit EOS token, healing enabled.
beam_params = BeamParams(K=5, eos_tokens={b"<|endoftext|>"}, heal=True)
# Leaving the `async with` block awaits `cleanup()` on the instance.
async with ByteLLM.from_name("gpt2", beam_params) as byte_llm:
    # The prompt is UTF-8 encoded and prepended to every context that is scored.
    byte_llm.set_prompt_from_str("Hello")
    # A context is a sequence of single-byte `bytes` objects.
    logp = await byte_llm.prefix([b" ", b"w", b"o", b"r", b"l", b"d"])
Source code in genlm/control/potential/built_in/bytellm.py
class ByteLLM(Potential):
    """A potential representing a language model operating at the byte level using beam search.

    `ByteLLM` wraps a language model and uses beam search to compute log probabilities
    over byte sequences. This enables constrained generation at the byte level while
    maintaining coherent token-level probabilities through adaptive token healing.

    Args:
        llm: The language model to use (from `genlm.backend`).
        beam_params (BeamParams): Configuration for beam search, including beam width `K`,
            `eos_tokens`, and healing parameters (`heal`, `heal_max_backoff`, `heal_max_splits`).
        cache_size (int): Maximum number of beam states to cache. Defaults to 1024.

    Example:
        ```python
        from genlm.bytes import BeamParams
        from genlm.control import ByteLLM

        beam_params = BeamParams(K=5, eos_tokens={b"<|endoftext|>"}, heal=True)
        async with ByteLLM.from_name("gpt2", beam_params) as byte_llm:
            byte_llm.set_prompt_from_str("Hello")
            logp = await byte_llm.prefix([b" ", b"w", b"o", b"r", b"l", b"d"])
        ```
    """

    def __init__(self, llm: Any, beam_params: BeamParams, cache_size: int = 1024):
        self.llm = llm
        self.beam_params = beam_params
        self.cache_size = cache_size
        # The vocabulary is the 256 possible single-byte values.
        vocab = [i.to_bytes(1, "big") for i in range(256)]
        super().__init__(vocabulary=vocab)
        # LRU cache of ByteBeamState keyed by full context bytes (prompt + context)
        self._beam_cache: OrderedDict[bytes, ByteBeamState] = OrderedDict()
        # Beam for the bare prompt; created lazily on first use.
        self._initial_beam = None
        self.prompt_bytes = b""
        # Fast path: cache last accessed beam for sequential access
        self._last_context = None
        self._last_beam = None

    @classmethod
    def from_name(
        cls,
        name,
        beam_params: BeamParams,
        backend=None,
        cache_size: int = 1024,
        **kwargs,
    ):
        """Create a `ByteLLM` by loading a model by name.

        Args:
            name: Model name, passed to `load_model_by_name`.
            beam_params (BeamParams): Configuration for beam search.
            backend: Backend identifier. When falsy, `"vllm"` is used if CUDA is
                available and `"hf"` otherwise.
            cache_size (int): Maximum number of beam states to cache. Defaults to 1024.
            **kwargs: Forwarded to `load_model_by_name`.

        Returns:
            (ByteLLM): The constructed instance.
        """
        backend = backend or ("vllm" if torch.cuda.is_available() else "hf")
        llm = load_model_by_name(name, backend=backend, **kwargs)
        return cls(llm, beam_params, cache_size=cache_size)

    def set_prompt_from_str(self, prompt_str: str):
        """Set the prompt that is prepended to every context.

        The prompt is stored as UTF-8 bytes. If it differs from the current prompt,
        every cached beam (keyed by prompt + context) is dropped so it gets rebuilt
        for the new prompt. Setting the same prompt again is a no-op.

        Note that the dropped beams are only dereferenced here; `cleanup` is not
        awaited on them because this method is synchronous.

        Args:
            prompt_str (str): The new prompt.
        """
        new_prompt_bytes = prompt_str.encode("utf-8")
        if new_prompt_bytes != self.prompt_bytes:
            self.prompt_bytes = new_prompt_bytes
            self._beam_cache.clear()
            self._initial_beam = None
            self._last_context = None
            self._last_beam = None

    async def _get_or_create_beam_for_context(self, context):
        """Return the beam state for `prompt_bytes + context`.

        Lookup order: an exact cache hit, then the most recently returned beam if its
        context is a prefix of the requested one, then the longest cached prefix, and
        finally the (lazily created) initial beam for the prompt. The chosen beam is
        advanced one byte at a time and every intermediate state is cached.

        Args:
            context: Sequence of `bytes` objects; they are concatenated with `b"".join`.

        Returns:
            (ByteBeamState): The beam state after consuming the prompt and the context.

        Raises:
            ValueError: If the beam becomes empty while consuming a byte.
        """
        context_bytes = b"".join(context)
        full_context_bytes = self.prompt_bytes + context_bytes

        # Fast path: exact cache hit
        if full_context_bytes in self._beam_cache:
            self._beam_cache.move_to_end(full_context_bytes)
            beam = self._beam_cache[full_context_bytes]
            self._last_context = full_context_bytes
            self._last_beam = beam
            return beam

        # Fast path: sequential access from last beam. An identical context is
        # accepted too (it only gets here if its entry was evicted from the cache),
        # so the beam is reused instead of being recomputed from the prompt.
        if self._last_context is not None and full_context_bytes.startswith(
            self._last_context
        ):
            best_prefix_bytes = self._last_context
            best_beam = self._last_beam
        else:
            # Search cache for longest prefix match
            best_prefix_bytes = b""
            best_beam = None
            for cached_prefix_bytes, cached_beam in self._beam_cache.items():
                if full_context_bytes.startswith(cached_prefix_bytes) and len(
                    cached_prefix_bytes
                ) > len(best_prefix_bytes):
                    best_prefix_bytes = cached_prefix_bytes
                    best_beam = cached_beam

            if best_beam is None:
                if self._initial_beam is None:
                    self._initial_beam = await ByteBeamState.initial(
                        self.llm, self.beam_params
                    )
                    if self.prompt_bytes:
                        self._initial_beam = await self._initial_beam.prefill(
                            self.prompt_bytes
                        )
                        self._cache_put(self.prompt_bytes, self._initial_beam)
                best_beam = self._initial_beam
                # The initial beam has already consumed the prompt, and
                # full_context_bytes always starts with it by construction.
                best_prefix_bytes = self.prompt_bytes

        # Advance beam byte-by-byte
        remaining_bytes = full_context_bytes[len(best_prefix_bytes) :]
        current_beam = best_beam
        current_prefix_bytes = best_prefix_bytes

        for i, byte_val in enumerate(remaining_bytes):
            current_beam = current_beam.prune()
            current_beam = await (current_beam << byte_val)
            current_prefix_bytes += remaining_bytes[i : i + 1]

            if len(current_beam) == 0:
                raise ValueError(
                    f"Beam became empty at byte {byte_val} ({chr(byte_val) if 32 <= byte_val < 127 else f'0x{byte_val:02x}'}). "
                    f"Context so far: {current_prefix_bytes!r}. "
                    f"Consider enabling healing or increasing beam width K."
                )

            self._cache_put(current_prefix_bytes, current_beam)

        # Update last beam for fast sequential access
        self._last_context = full_context_bytes
        self._last_beam = current_beam

        return current_beam

    def _cache_put(self, key: bytes, beam: ByteBeamState):
        """Insert `beam` under `key` as the most recent entry, evicting the oldest entries.

        Args:
            key (bytes): Full context bytes (prompt + context) the beam corresponds to.
            beam (ByteBeamState): Beam state to store.
        """
        self._beam_cache[key] = beam
        self._beam_cache.move_to_end(key)
        # Stop once the cache is empty so a negative cache_size behaves like 0
        # instead of raising KeyError from popitem().
        while self._beam_cache and len(self._beam_cache) > self.cache_size:
            self._beam_cache.popitem(last=False)

    async def prefix(self, context):
        """Return the log weight of `context` as a prefix, relative to the prompt.

        Args:
            context: Sequence of `bytes` objects.

        Returns:
            (float): `0.0` for an empty context; otherwise the beam's `logZ` minus
                the `logZ` of the initial (prompt) beam.
        """
        # Treat empty context as neutral (log 1 = 0), matching PromptedLLM semantics.
        # The prompt, if set, is incorporated into next-token distributions via the cached beam,
        # but does not contribute to the prefix weight of the empty context.
        if not context:
            return 0.0
        beam = await self._get_or_create_beam_for_context(context)
        base = self._initial_beam.logZ if self._initial_beam is not None else 0.0
        return beam.logZ - base

    async def complete(self, context):
        """Return the log weight of `context` as a complete sequence.

        This is the prompt-relative `logZ` of the context's beam plus the last
        entry of the beam's next-byte distribution.

        Args:
            context: Sequence of `bytes` objects.

        Returns:
            (float): Log weight of the context followed by EOS.
        """
        beam = await self._get_or_create_beam_for_context(context)
        logp_next = await beam.logp_next()
        # Assume logp_next.ps contains log-probs for 256 byte values plus EOS at the end.
        eos_logp = logp_next.ps[-1]
        base = self._initial_beam.logZ if self._initial_beam is not None else 0.0
        return (beam.logZ - base) + eos_logp

    async def logw_next(self, context):
        """Efficient next-token weights using the cached beam state.

        Uses the beam's next-token distribution directly instead of the
        default (slower) fallback that recomputes scores for each token.
        """
        beam = await self._get_or_create_beam_for_context(context)
        logp_next = await beam.logp_next()

        # Build weights over vocab_eos (256 bytes + EOS at the end)
        ps = np.asarray(logp_next.ps)
        logws = self.alloc_logws()
        v = len(self.vocab)
        logws[:v] = ps[:v]
        logws[-1] = ps[-1]
        return self.make_lazy_weights(logws)

    async def cleanup(self):
        """Cleans up resources used by the beam states.

        This method is called automatically when using ByteLLM as an async context manager.
        If not using a context manager, you should call this method manually when done.

        Every beam this object still references (the initial beam, the last returned
        beam and all cached beams) is cleaned up exactly once. Afterwards all of those
        references are dropped, so a later call builds fresh beams instead of reusing
        cleaned-up ones.
        """
        # Deduplicate by identity: the initial beam is also stored in the cache
        # under the prompt key, and the last beam is normally cached as well.
        # Beams are compared against None rather than by truthiness because beam
        # states define __len__, so an empty beam would be falsy.
        unique_beams = {}
        for beam in (self._initial_beam, self._last_beam, *self._beam_cache.values()):
            if beam is not None:
                unique_beams.setdefault(id(beam), beam)

        try:
            for beam in unique_beams.values():
                await beam.cleanup()
        finally:
            # Drop the references even if a cleanup call failed, so no
            # cleaned-up beam can be handed out again.
            self._beam_cache.clear()
            self._initial_beam = None
            self._last_context = None
            self._last_beam = None

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - ensures cleanup is called."""
        await self.cleanup()
        return False

logw_next(context) async

Efficient next-token weights using the cached beam state.

Uses the beam's next-token distribution directly instead of the default (slower) fallback that recomputes scores for each token.

Source code in genlm/control/potential/built_in/bytellm.py
async def logw_next(self, context):
    """Compute next-token log weights straight from the cached beam state.

    The beam for `context` already carries a next-token distribution, so it is
    copied into a weight vector rather than scoring every token separately
    through the default (slower) fallback.
    """
    state = await self._get_or_create_beam_for_context(context)
    next_dist = await state.logp_next()
    log_probs = np.asarray(next_dist.ps)

    # Lay the weights out over vocab_eos (256 bytes + EOS at the end).
    n_bytes = len(self.vocab)
    weights = self.alloc_logws()
    weights[:n_bytes] = log_probs[:n_bytes]
    weights[-1] = log_probs[-1]
    return self.make_lazy_weights(weights)

cleanup() async

Cleans up resources used by the beam states.

This method is called automatically when using ByteLLM as an async context manager. If not using a context manager, you should call this method manually when done.

Source code in genlm/control/potential/built_in/bytellm.py
async def cleanup(self):
    """Cleans up resources used by the beam states.

    This method is called automatically when using ByteLLM as an async context manager.
    If not using a context manager, you should call this method manually when done.

    Every beam this object still references (the initial beam, the last returned
    beam and all cached beams) is cleaned up exactly once. Afterwards all of those
    references are dropped, so a later call builds fresh beams instead of reusing
    cleaned-up ones.
    """
    # Deduplicate by identity: the same beam object can be referenced as the
    # initial beam, as the last beam and from the cache at once.
    # Beams are compared against None rather than by truthiness so that a beam
    # whose length is 0 is still cleaned up.
    unique_beams = {}
    for beam in (self._initial_beam, self._last_beam, *self._beam_cache.values()):
        if beam is not None:
            unique_beams.setdefault(id(beam), beam)

    try:
        for beam in unique_beams.values():
            await beam.cleanup()
    finally:
        # Drop the references even if a cleanup call failed, so no
        # cleaned-up beam can be handed out again.
        self._beam_cache.clear()
        self._initial_beam = None
        self._last_context = None
        self._last_beam = None

__aenter__() async

Async context manager entry.

Source code in genlm/control/potential/built_in/bytellm.py
async def __aenter__(self):
    """Enter the ``async with`` block and hand back this same instance."""
    return self

__aexit__(exc_type, exc_val, exc_tb) async

Async context manager exit - ensures cleanup is called.

Source code in genlm/control/potential/built_in/bytellm.py
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the ``async with`` block, awaiting `cleanup` on the way out.

    Returns False in every case, so an exception raised inside the block
    is not suppressed.
    """
    await self.cleanup()
    suppress_exception = False
    return suppress_exception