Skip to content

byte_lm

genlm.bytes.byte_lm

ByteBeamState

Bases: StatefulByteLM

Represents the state of the beam during byte-level language modeling.

Tracks multiple candidate states and their probabilities, pruning low-probability candidates.

Parameters:

Name Type Description Default
states list[LazyTrieState]

List of candidate states to track

required
params BeamParams

Parameters controlling beam search behavior

required
Source code in genlm/bytes/byte_lm/beam.py
class ByteBeamState(StatefulByteLM):
    """Represents the state of the beam during byte-level language modeling.

    Tracks multiple candidate states and their probabilities, pruning low-probability
    candidates.

    Args:
        states (list[LazyTrieState]): List of candidate states to track
        params (BeamParams): Parameters controlling beam search behavior
    """

    def __init__(self, states, params):
        # Candidates are kept sorted by decreasing weight so that `prune` can
        # keep the top-K with a plain slice.
        self.states = sorted(states, key=lambda b: -b.weight)
        self.params = params

    @classmethod
    async def initial(cls, llm, params, trie_opts=None):
        """Creates initial beam state.

        Args:
            llm (genlm.backend.AsyncLM): Token-level language model to use. It is
                wrapped in a StatefulTokenizedLM by LazyTrieState.initial, and its
                `tokenizer` is used to build the byte vocabulary.
            params (BeamParams): Beam search parameters.
            trie_opts (dict, optional): Additional keyword arguments passed to
                AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}.
                The dict is copied, not modified; any "eos_tokens" key in it is
                overridden by `params.eos_tokens`.

        Returns:
            (ByteBeamState): Initial beam state.
        """
        # Build a fresh options dict instead of writing into the caller's one,
        # so a trie_opts dict reused across calls is never silently mutated.
        trie_opts = {**(trie_opts or {}), "eos_tokens": params.eos_tokens}

        async_trie = AsyncTokenByteTrie.from_vocab(
            get_byte_vocab(llm.tokenizer), **trie_opts
        )
        state = LazyTrieState.initial(llm, async_trie, mode=TrieMode.WITH_EOS)
        return cls([await state.materialize()], params)

    def __iter__(self):
        return iter(self.states)

    def __len__(self):
        return len(self.states)

    @cached_property
    def logZ(self):
        """Estimate of the partition function (sum of weights) for the current beam.
        This is the estimate of the prefix probability of the bytes consumed so far.
        """
        return logsumexp([state.weight for state in self])

    async def __lshift__(self, a):
        """Advances the beam state with a new byte.

        Every candidate is advanced directly by `a`. Candidates that can end
        their current token are also extended (see `extend`), and the extended
        states are advanced by `a` as well.

        Args:
            a (int): Byte to add to states.

        Returns:
            (ByteBeamState): New beam state after processing the byte.
        """
        advanced = []
        for state in self:
            if next_state := state << a:
                advanced.append(next_state)

        # Pruning reference for the extensions: mass of the direct advances.
        logZ = logsumexp([s.weight for s in advanced]) if advanced else -np.inf
        for state in await self.extend(logZ):
            if next_state := state << a:
                advanced.append(next_state)

        new_beam = ByteBeamState(advanced, self.params)

        # If advancing would empty the beam, do adaptive healing if enabled
        if self.params.heal and len(new_beam) == 0:
            healed = await self._adaptive_heal(a)
            if healed is not None:
                if self.params.verbose:
                    print("[heal] Applied adaptive token healing")
                return healed

        if self.params.verbose:
            print()
            print(new_beam)

        return new_beam

    async def logp_next(self):
        """Computes log probabilities for the next byte across all beam candidates.

        Returns:
            (LazyByteProbs): Log probabilities for next possible bytes.

        Raises:
            AssertionError: If the beam is empty.
        """
        assert len(self) > 0, "Beam is empty"

        logqs = []
        for state in self:
            logqs.append(state.logp_next.ps + state.weight)

        for state in await self.extend(self.logZ):
            logqs.append(state.logp_next.ps + state.weight)

        logqs = np.stack(logqs, axis=0)  # shape: (num_states, array_size)
        # Mask the EOT column (index 256, i.e. -2 of the 258-wide rows) for the
        # non-extended candidates; the extended rows already consumed that EOT.
        logqs[: len(self), -2] = -np.inf
        logps = scipy_logsumexp(logqs, axis=0)

        return LazyByteProbs(logps - logsumexp(logps))

    async def extend(self, logZ):
        """Attempts to advance each candidate in the beam by a token (EOT).

        For each candidate with an EOT available, this ends the current token and
        starts a new one in preparation for the next byte.

        Args:
            logZ (float): Current estimate of the partition function, used for
                pruning. The weights of the extended states are added to it
                before the prune test.

        Returns:
            (list[LazyTrieState]): New, materialized candidate states after extension
        """
        extends = []
        for state in self:
            if new_state := state.extend():
                logZ = np.logaddexp(logZ, new_state.weight)
                extends.append(new_state)

        # Only materialize (which queries the LM) extensions that survive pruning.
        coros = []
        for state in extends:
            if state.weight - logZ > self.params.log_prune_threshold:
                coros.append(state.materialize())

        return await asyncio.gather(*coros)

    def prune(self):
        """Prunes beam to maintain beam width and probability threshold.

        Returns:
            (ByteBeamState): New state with pruned candidates.
        """
        # self.states is sorted by decreasing weight, so slicing keeps the top-K.
        new_states = [
            state
            for state in self
            if state.weight - self.logZ > self.params.log_prune_threshold
        ][: self.params.K]
        return ByteBeamState(new_states, self.params)

    def __repr__(self):
        desc = colors.bold % f"Z: {self.logZ}\n" + colors.bold % "Candidates:\n"
        for state in self:
            P = np.exp(state.weight - self.logZ)
            color = colors.green if P > self.params.prune_threshold else colors.red
            desc += f"({color % f'{P:.4f}'}) {repr(state)}\n"
        return desc

    def with_mode(self, mode):
        """Create a new beam state with specified trie mode.

        Args:
            mode (TrieMode): Trie mode for the new beam state

        Returns:
            (ByteBeamState): New beam state with updated mode
        """
        return ByteBeamState(
            states=[state.with_mode(mode) for state in self.states],
            params=self.params,
        )

    async def prefill(self, bs):
        """Prefill the beam on a sequence of bytes.

        During prefilling, EOS tokens are treated as normal tokens and don't cause termination.

        Args:
            bs (bytes): Byte sequence to prefill on

        Returns:
            (ByteBeamState): New beam state after prefilling
        """
        # Create no_eos beam for prefill (EOS tokens treated as normal)
        no_eos_beam = self.with_mode(TrieMode.WITHOUT_EOS)

        # Do prefill operations on no_eos beam
        for b in bs:
            no_eos_beam = await (no_eos_beam.prune() << b)

        # Return as with_eos beam (EOS tokens get special handling after prefill)
        return no_eos_beam.with_mode(TrieMode.WITH_EOS)

    async def cleanup(self):
        """Cleans up resources used by the candidates."""
        # NOTE(review): candidates derived from `initial` all share one trie, so
        # the trie's cleanup runs once per candidate — confirm it is idempotent.
        await asyncio.gather(*[state.cleanup() for state in self])

    async def _adaptive_heal(self, next_byte: int):
        """Attempt adaptive token healing using TokenHealer.

        Candidates are tried in beam order (highest weight first).

        Returns a new beam advanced by `next_byte` if healing succeeds, else None.
        """
        healer = TokenHealer(
            max_backoff=self.params.heal_max_backoff,
            max_splits=self.params.heal_max_splits,
            verbose=self.params.verbose,
        )

        for state in self.states:
            healed_state = await healer.try_heal(state, next_byte)
            if healed_state is not None:
                return ByteBeamState([healed_state], self.params)

        return None

initial(llm, params, trie_opts=None) async classmethod

Creates initial beam state.

Parameters:

Name Type Description Default
llm AsyncLM

Token-level language model to use.

required
params BeamParams

Beam search parameters.

required
trie_opts dict

Additional keyword arguments passed to AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}.

None

Returns:

Type Description
ByteBeamState

Initial beam state.

Source code in genlm/bytes/byte_lm/beam.py
@classmethod
async def initial(cls, llm, params, trie_opts=None):
    """Creates initial beam state.

    Args:
        llm (genlm.backend.AsyncLM): Token-level language model to use. It is
            wrapped in a StatefulTokenizedLM by LazyTrieState.initial, and its
            `tokenizer` is used to build the byte vocabulary.
        params (BeamParams): Beam search parameters.
        trie_opts (dict, optional): Additional keyword arguments passed to
            AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}.
            The dict is copied, not modified; any "eos_tokens" key in it is
            overridden by `params.eos_tokens`.

    Returns:
        (ByteBeamState): Initial beam state.
    """
    # Build a fresh options dict instead of writing into the caller's one,
    # so a trie_opts dict reused across calls is never silently mutated.
    trie_opts = {**(trie_opts or {}), "eos_tokens": params.eos_tokens}

    async_trie = AsyncTokenByteTrie.from_vocab(
        get_byte_vocab(llm.tokenizer), **trie_opts
    )
    state = LazyTrieState.initial(llm, async_trie, mode=TrieMode.WITH_EOS)
    return cls([await state.materialize()], params)

logZ cached property

Estimate of the partition function (sum of weights) for the current beam. This is the estimate of the prefix probability of the bytes consumed so far.

__lshift__(a) async

Advances the beam state with a new byte.

Parameters:

Name Type Description Default
a int

Byte to add to states.

required

Returns:

Type Description
ByteBeamState

New beam state after processing the byte.

Source code in genlm/bytes/byte_lm/beam.py
async def __lshift__(self, a):
    """Advance every beam candidate by the byte `a`.

    Candidates are advanced directly; those able to close their current token
    are additionally extended (see `extend`) and then advanced by `a`.

    Args:
        a (int): Byte to add to states.

    Returns:
        (ByteBeamState): New beam state after processing the byte.
    """
    direct = [nxt for nxt in (cand << a for cand in self) if nxt]

    if direct:
        logZ = logsumexp([cand.weight for cand in direct])
    else:
        logZ = -np.inf

    extended = await self.extend(logZ)
    via_eot = [nxt for nxt in (cand << a for cand in extended) if nxt]

    result = ByteBeamState(direct + via_eot, self.params)

    # An empty beam means no candidate could take `a`; try healing first.
    if self.params.heal and len(result) == 0:
        healed = await self._adaptive_heal(a)
        if healed is not None:
            if self.params.verbose:
                print("[heal] Applied adaptive token healing")
            return healed

    if self.params.verbose:
        print()
        print(result)

    return result

logp_next() async

Computes log probabilities for the next byte across all beam candidates.

Returns:

Type Description
LazyByteProbs

Log probabilities for next possible bytes.

Source code in genlm/bytes/byte_lm/beam.py
async def logp_next(self):
    """Return the normalized next-byte distribution of the whole beam.

    Each candidate contributes its own next-byte log-probabilities shifted by
    its weight. Candidates that can end their current token also contribute
    through their extensions (see `extend`).

    Returns:
        (LazyByteProbs): Log probabilities for next possible bytes.
    """
    assert len(self) > 0, "Beam is empty"

    current_rows = [cand.logp_next.ps + cand.weight for cand in self]
    extended = await self.extend(self.logZ)
    extended_rows = [cand.logp_next.ps + cand.weight for cand in extended]

    table = np.stack(current_rows + extended_rows, axis=0)  # (num_states, array_size)
    # Column -2 is the EOT slot (index 256). It is cleared for the current
    # candidates; the extended rows already consumed that EOT.
    table[: len(current_rows), -2] = -np.inf
    per_byte = scipy_logsumexp(table, axis=0)

    return LazyByteProbs(per_byte - logsumexp(per_byte))

extend(logZ) async

Attempts to advance each candidate in the beam by a token (EOT).

For each candidate with an EOT available, this ends the current token and starts a new one in preparation for the next byte.

Parameters:

Name Type Description Default
logZ float

Current estimate of the partition function, used for pruning

required

Returns:

Type Description
list[LazyTrieState]

New candidate states after extension

Source code in genlm/bytes/byte_lm/beam.py
async def extend(self, logZ):
    """Close the current token of every candidate that has an EOT edge.

    Each such candidate yields a new state positioned at the trie root, ready to
    consume the first byte of the next token. The weights of these new states
    are folded into `logZ`. Only states whose weight relative to that total
    exceeds the prune threshold are materialized.

    Args:
        logZ (float): Current estimate of the partition function, used for pruning.

    Returns:
        (list[LazyTrieState]): Materialized candidate states after extension.
    """
    candidates = [ext for ext in (cand.extend() for cand in self) if ext]

    for cand in candidates:
        logZ = np.logaddexp(logZ, cand.weight)

    pending = [
        cand.materialize()
        for cand in candidates
        if cand.weight - logZ > self.params.log_prune_threshold
    ]
    return await asyncio.gather(*pending)

prune()

Prunes beam to maintain beam width and probability threshold.

Returns:

Type Description
ByteBeamState

New state with pruned candidates.

Source code in genlm/bytes/byte_lm/beam.py
def prune(self):
    """Drop low-weight candidates and cap the beam at K entries.

    A candidate survives if its log-weight relative to `logZ` exceeds the
    log prune threshold. Survivors are then truncated to the first K, which are
    the heaviest because the beam is kept sorted by weight.

    Returns:
        (ByteBeamState): New state with pruned candidates.
    """
    threshold = self.params.log_prune_threshold
    survivors = [
        candidate
        for candidate in self
        if candidate.weight - self.logZ > threshold
    ]
    return ByteBeamState(survivors[: self.params.K], self.params)

with_mode(mode)

Create a new beam state with specified trie mode.

Parameters:

Name Type Description Default
mode TrieMode

Trie mode for the new beam state

required

Returns:

Type Description
ByteBeamState

New beam state with updated mode

Source code in genlm/bytes/byte_lm/beam.py
def with_mode(self, mode):
    """Return a copy of this beam whose candidates all use trie mode `mode`.

    Args:
        mode (TrieMode): Trie mode for the new beam state

    Returns:
        (ByteBeamState): New beam state with updated mode
    """
    remoded = [candidate.with_mode(mode) for candidate in self.states]
    return ByteBeamState(states=remoded, params=self.params)

prefill(bs) async

Prefill the beam on a sequence of bytes.

During prefilling, EOS tokens are treated as normal tokens and don't cause termination.

Parameters:

Name Type Description Default
bs bytes

Byte sequence to prefill on

required

Returns:

Type Description
ByteBeamState

New beam state after prefilling

Source code in genlm/bytes/byte_lm/beam.py
async def prefill(self, bs):
    """Feed a byte sequence through the beam without EOS handling.

    The beam is switched to WITHOUT_EOS mode so that EOS tokens behave like any
    other token. Each byte is consumed after a prune step, and the result is
    switched back to WITH_EOS mode.

    Args:
        bs (bytes): Byte sequence to prefill on

    Returns:
        (ByteBeamState): New beam state after prefilling
    """
    beam = self.with_mode(TrieMode.WITHOUT_EOS)

    for byte in bs:
        pruned = beam.prune()
        beam = await (pruned << byte)

    return beam.with_mode(TrieMode.WITH_EOS)

cleanup() async

Cleans up resources used by the candidates.

Source code in genlm/bytes/byte_lm/beam.py
async def cleanup(self):
    """Run every candidate's cleanup concurrently and wait for all of them."""
    pending = [candidate.cleanup() for candidate in self]
    await asyncio.gather(*pending)

BeamParams dataclass

Parameters for byte-level beam summing algorithm.

Parameters:

Name Type Description Default
K int

Beam width - maximum number of candidates to maintain.

required
prune_threshold float

Probability threshold for pruning candidates. Candidates with probability below this are removed. Defaults to 0.0

0.0
verbose bool

Whether to print the beam state at each step. Defaults to False

False
eos_tokens list[bytes]

List of tokens that should be treated as EOS. When configured, EOS tokens will terminate generation when sampled. Defaults to None

None
heal bool

Whether to enable adaptive token healing. Defaults to True

True
heal_max_backoff int

Maximum number of bytes to back off when healing. Defaults to None

None
heal_max_splits int

Maximum number of intra-suffix commits allowed during a single healing attempt. Defaults to None

None
Source code in genlm/bytes/byte_lm/beam.py
@dataclass
class BeamParams:
    """Parameters for byte-level beam summing algorithm.

    Args:
        K (int): Beam width - maximum number of candidates to maintain.
        prune_threshold (float, optional): Probability threshold for pruning candidates.
            Candidates with probability below this are removed. Defaults to 0.0
        verbose (bool, optional): Whether to print the beam state at each step. Defaults to False
        eos_tokens (list[bytes], optional): List of tokens that should be treated as EOS. When configured,
            EOS tokens will terminate generation when sampled. Defaults to None
        heal (bool, optional): Whether to enable adaptive token healing. Defaults to True
        heal_max_backoff (int, optional): Maximum number of bytes to back off when healing. Defaults to None
        heal_max_splits (int, optional): Maximum number of intra-suffix commits allowed during a single healing attempt. Defaults to None
    """

    K: int
    prune_threshold: float = 0.0
    verbose: bool = False
    eos_tokens: list[bytes] = None
    heal: bool = True
    heal_max_backoff: int | None = None
    # Optional cap on how many intra-partial commits are allowed during a
    # single healing attempt. None means unlimited. Set to 0 to disable
    # multi-split behavior (i.e., single-split only).
    heal_max_splits: int | None = None

    def __post_init__(self):
        if self.prune_threshold < 0:
            raise ValueError(
                f"prune_threshold must be non-negative, got {self.prune_threshold}"
            )
        self.log_prune_threshold = (
            np.log(self.prune_threshold) if self.prune_threshold > 0 else -np.inf
        )
        self.eos_tokens = set(self.eos_tokens) if self.eos_tokens else set()

LazyTrieState

A lazy-evaluated state of a TokenByteTrie traversal.

This class maintains the state of a language model while traversing a trie structure, lazily evaluating probabilities and maintaining the weight of the current path through the trie for beam search.

Parameters:

Name Type Description Default
lm_state StatefulTokenizedLM

Current language model state

required
trie TokenByteTrie

Trie structure mapping tokens to byte sequences

required
node int

Current node in the trie

required
weight float

Cumulative log probability of the path to this node

required
mass ndarray

Masses for each node in the trie for the current state

None
mode TrieMode

Trie mode to use

WITH_EOS
terminated bool

Whether the state is terminated (EOS has been consumed)

False
Source code in genlm/bytes/byte_lm/trie_state.py
class LazyTrieState:
    """A lazy-evaluated state of a TokenByteTrie traversal.

    This class maintains the state of a language model while traversing a trie structure,
    lazily evaluating probabilities and maintaining the weight of the current path through the trie
    for beam search.

    Args:
        lm_state (StatefulTokenizedLM): Current language model state
        trie (TokenByteTrie): Trie structure mapping tokens to byte sequences
        node (int): Current node in the trie
        weight (float): Cumulative log probability of the path to this node
        mass (numpy.ndarray, optional): Log masses for each node in the trie for the current state
        mode (TrieMode): Trie mode to use
        terminated (bool): Whether the state is terminated (EOS has been consumed)
    """

    def __init__(
        self,
        lm_state,
        trie,
        node,
        weight,
        mass=None,
        mode=TrieMode.WITH_EOS,
        terminated=False,
    ):
        self.lm_state = lm_state
        self.trie = trie
        self.node = node
        self.weight = weight
        self._mass = mass
        # Cached result of `extend()` (computed on first successful call).
        self._extend = None
        self.mode = mode
        self.root = self.trie.trie.root
        self.children = self.trie.trie.children
        self.terminated = terminated

    @classmethod
    def initial(cls, lm, trie, mode=TrieMode.WITH_EOS):
        """Creates an initial trie state.

        Args:
            lm (genlm.backend.AsyncLM): Language model to use
            trie (TokenByteTrie): TokenByteTrie structure for byte-to-token mapping
            mode (TrieMode): Trie mode to use

        Returns:
            (LazyTrieState): Initial state at root of trie with weight 0.0
        """
        return cls(
            trie=trie,
            node=trie.trie.root,
            lm_state=StatefulTokenizedLM.initial(lm),
            weight=0.0,
            mode=mode,
        )

    @property
    def partial(self):
        """Returns the byte sequence corresponding to the current node in the trie."""
        return self.trie.trie.node2prefix[self.node]

    @property
    def mass(self):
        """Returns the log mass for each node in the trie.

        The mass at a node corresponds to the sum of the probabilities of all
        tokens which share the prefix (`self.partial`) represented by that node.

        Raises:
            ValueError: If state hasn't been materialized yet
        """
        if self._mass is None:
            raise ValueError("State is not yet materialized.")
        return self._mass

    def with_mode(self, mode):
        """Returns a new state with the given mode."""
        return LazyTrieState(
            lm_state=self.lm_state,
            trie=self.trie,
            node=self.node,
            weight=self.weight,
            mass=self._mass,
            mode=mode,
            terminated=self.terminated,
        )

    def actions(self):
        """Returns possible byte transitions from current node."""
        return self.children[self.node]

    def get_EOT(self):
        """Returns the end-of-token node if available from current position in the trie."""
        return self.children[self.node].get(self.trie.trie.eot_token)

    def __lshift__(self, b):
        """Transitions to a new state by consuming a byte.

        Args:
            b (int): Byte to consume

        Returns:
            (LazyTrieState|None): New state after consuming byte, or None if the
                state is terminated or the current node has no edge for `b`.
                Consuming EOS yields a state marked as terminated.
        """
        if self.terminated:
            return None

        # Compare against None (as get_EOT/extend do) instead of relying on
        # truthiness, which would wrongly reject a child whose node id is 0.
        node = self.children[self.node].get(b)
        if node is None:
            return None

        mass = self.mass
        return LazyTrieState(
            lm_state=self.lm_state,
            trie=self.trie,
            mass=mass,
            node=node,
            # Log-ratio of child to parent mass = log P(b | current prefix).
            weight=self.weight + mass[node] - mass[self.node],
            mode=self.mode,
            terminated=b == EOS,
        )

    def extend(self):
        """Extends current state by consuming an end-of-token if possible.

        Returns:
            (LazyTrieState|None): New state after consuming EOT, or None if not possible
        """
        # NOTE(review): this does not consult `self.terminated`, and the new
        # state defaults to terminated=False — confirm a terminated node can
        # never have an EOT edge, or that continuing past EOS is intended.
        if self._extend is None:
            if (eot_node := self.get_EOT()) is not None:
                mass = self.mass
                self._extend = LazyTrieState(
                    lm_state=self.lm_state
                    << int(self.trie.trie.leaf2token_id[eot_node]),
                    trie=self.trie,
                    node=self.root,
                    weight=self.weight + mass[eot_node] - mass[self.node],
                    mode=self.mode,
                )
        return self._extend

    @cached_property
    def logp_next(self):
        """Computes log probabilities for next possible transitions.

        Returns:
            (LazyByteProbs): Lazy log probability distribution over possible next bytes
        """
        logps = np.full(258, -np.inf)  # 258 for EOT, EOS + 256 for normal bytes
        mass = self.mass
        logZ = mass[self.node]

        for byte, node in self.actions().items():
            # The EOT edge is keyed by None and stored in slot 256.
            logps[byte if byte is not None else 256] = mass[node] - logZ

        return LazyByteProbs(logps)

    async def materialize(self):
        """Materializes the masses for each node in the trie for the current state.

        This makes a call to the language model and the underlying trie.

        Returns:
            (LazyTrieState): Self with materialized masses
        """
        if self._mass is None:
            logp_next = await self.lm_state.logp_next()
            # weight_sum aggregates token probabilities into per-node masses;
            # these are then stored in log space.
            node_mass = await self.trie.weight_sum(torch.exp(logp_next), self.mode)
            log_mass = torch.log(node_mass)
            self._mass = log_mass.cpu().numpy()
        return self

    def __repr__(self):
        context = colors.green % ("|" + escape(bytes(self.partial)))
        if self.terminated:
            context += colors.yellow % "<EOS>"
        return f"{self.weight:.2f}: {self.lm_state}" + context

    async def cleanup(self):
        """Cleans up resources used by the trie."""
        await self.trie.cleanup()

initial(lm, trie, mode=TrieMode.WITH_EOS) classmethod

Creates an initial trie state.

Parameters:

Name Type Description Default
lm AsyncLM

Language model to use

required
trie TokenByteTrie

TokenByteTrie structure for byte-to-token mapping

required
mode TrieMode

Trie mode to use

WITH_EOS

Returns:

Type Description
LazyTrieState

Initial state at root of trie with weight 0.0

Source code in genlm/bytes/byte_lm/trie_state.py
@classmethod
def initial(cls, lm, trie, mode=TrieMode.WITH_EOS):
    """Build the starting traversal state: trie root, zero log-weight.

    Args:
        lm (genlm.backend.AsyncLM): Language model to use
        trie (TokenByteTrie): TokenByteTrie structure for byte-to-token mapping
        mode (TrieMode): Trie mode to use

    Returns:
        (LazyTrieState): Initial state at root of trie with weight 0.0
    """
    root_node = trie.trie.root
    fresh_lm = StatefulTokenizedLM.initial(lm)
    return cls(lm_state=fresh_lm, trie=trie, node=root_node, weight=0.0, mode=mode)

partial property

Returns the byte sequence corresponding to the current node in the trie.

mass property

Returns the log mass for each node in the trie.

The mass at a node corresponds to the sum of the probabilities of all tokens which share the prefix (self.partial) represented by that node.

Raises:

Type Description
ValueError

If state hasn't been materialized yet

with_mode(mode)

Returns a new state with the given mode.

Source code in genlm/bytes/byte_lm/trie_state.py
def with_mode(self, mode):
    """Return a copy of this state that differs only in its trie mode."""
    shared = dict(
        lm_state=self.lm_state,
        trie=self.trie,
        node=self.node,
        weight=self.weight,
        mass=self._mass,
        terminated=self.terminated,
    )
    return LazyTrieState(mode=mode, **shared)

actions()

Returns possible byte transitions from current node.

Source code in genlm/bytes/byte_lm/trie_state.py
def actions(self):
    """Return the mapping of outgoing edges (byte -> child node) of the current node."""
    outgoing = self.children[self.node]
    return outgoing

get_EOT()

Returns the end-of-token node if available from current position in the trie.

Source code in genlm/bytes/byte_lm/trie_state.py
def get_EOT(self):
    """Return the child reached via the end-of-token edge, or None if there is none."""
    outgoing = self.children[self.node]
    return outgoing.get(self.trie.trie.eot_token)

__lshift__(b)

Transitions to a new state by consuming a byte.

Parameters:

Name Type Description Default
b int

Byte to consume

required

Returns:

Type Description
LazyTrieState | None

New state after consuming byte, or None if transition invalid (terminated or EOS)

Source code in genlm/bytes/byte_lm/trie_state.py
def __lshift__(self, b):
    """Transitions to a new state by consuming a byte.

    Args:
        b (int): Byte to consume

    Returns:
        (LazyTrieState|None): New state after consuming byte, or None if the
            state is terminated or the current node has no edge for `b`.
            Consuming EOS yields a state marked as terminated.
    """
    if self.terminated:
        return None

    # Compare against None (as get_EOT/extend do) instead of relying on
    # truthiness, which would wrongly reject a child whose node id is 0.
    node = self.children[self.node].get(b)
    if node is None:
        return None

    mass = self.mass
    return LazyTrieState(
        lm_state=self.lm_state,
        trie=self.trie,
        mass=mass,
        node=node,
        # Log-ratio of child to parent mass = log P(b | current prefix).
        weight=self.weight + mass[node] - mass[self.node],
        mode=self.mode,
        terminated=b == EOS,
    )

extend()

Extends current state by consuming an end-of-token if possible.

Returns:

Type Description
LazyTrieState | None

New state after consuming EOT, or None if not possible

Source code in genlm/bytes/byte_lm/trie_state.py
def extend(self):
    """Close the current token via the EOT edge, if the current node has one.

    The resulting state sits at the trie root with the token appended to the
    LM context. It is computed once and cached in `self._extend`.

    Returns:
        (LazyTrieState|None): New state after consuming EOT, or None if not possible
    """
    if self._extend is not None:
        return self._extend

    eot_node = self.get_EOT()
    if eot_node is None:
        return None

    mass = self.mass
    token_id = int(self.trie.trie.leaf2token_id[eot_node])
    self._extend = LazyTrieState(
        lm_state=self.lm_state << token_id,
        trie=self.trie,
        node=self.root,
        weight=self.weight + mass[eot_node] - mass[self.node],
        mode=self.mode,
    )
    return self._extend

logp_next cached property

Computes log probabilities for next possible transitions.

Returns:

Type Description
LazyByteProbs

Lazy log probability distribution over possible next bytes

materialize() async

Materializes the masses for each node in the trie for the current state.

This makes a call to the language model and the underlying trie.

Returns:

Type Description
LazyTrieState

Self with materialized masses

Source code in genlm/bytes/byte_lm/trie_state.py
async def materialize(self):
    """Compute and cache the per-node log masses for this state, if not yet done.

    Queries the language model for next-token log-probabilities and asks the
    trie to aggregate them into per-node masses, which are stored in log space.

    Returns:
        (LazyTrieState): Self with materialized masses
    """
    if self._mass is not None:
        return self

    token_logps = await self.lm_state.logp_next()
    node_mass = await self.trie.weight_sum(torch.exp(token_logps), self.mode)
    self._mass = torch.log(node_mass).cpu().numpy()
    return self

cleanup() async

Cleans up resources used by the trie.

Source code in genlm/bytes/byte_lm/trie_state.py
async def cleanup(self):
    """Release resources held by the underlying trie."""
    trie = self.trie
    await trie.cleanup()

StatefulTokenizedLM

A stateful tokenized language model that maintains context and generates next token logprobs.

Parameters:

Name Type Description Default
model AsyncLM

The underlying language model

required
context list

List of token IDs representing the current context

required
n_calls int

Number of times the model has been called

0
max_context_length int

Maximum length of context to maintain

None
Source code in genlm/bytes/byte_lm/lm_state.py
class StatefulTokenizedLM:
    """A stateful tokenized language model that maintains context and generates next token logprobs.

    `<<` returns a new state and leaves the receiver's context untouched, so
    several states (e.g. beam candidates) can safely share a common ancestor.

    Args:
        model (genlm.backend.AsyncLM): The underlying language model
        context (list): List of token IDs representing the current context
        n_calls (int): Number of times the model has been called
        max_context_length (int, optional): Maximum length of context to maintain
    """

    def __init__(self, model, context, n_calls=0, max_context_length=None):
        self.model = model
        self.context = context
        self._n_calls = n_calls
        self.max_context_length = max_context_length

    @classmethod
    def initial(cls, model, initial_context=None, max_context_length=None):
        """Creates an initial state for the language model.

        Args:
            model (genlm.backend.AsyncLM): The language model to use
            initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id]
            max_context_length (int, optional): Maximum context length to maintain

        Returns:
            (StatefulTokenizedLM): A new instance with initial state
        """
        if initial_context is None:
            initial_context = [model.tokenizer.bos_token_id]
        return cls(model, initial_context, max_context_length=max_context_length)

    def __lshift__(self, token):
        """Adds a new token to the context and returns a new state.

        When `max_context_length` is set, the oldest tokens are dropped so the
        new context holds at most that many tokens. The limit is carried over
        to the new state so it keeps applying on later shifts.

        Args:
            token (int): Token ID to add to context

        Returns:
            (StatefulTokenizedLM): New state with updated context
        """
        assert isinstance(token, int)
        context = self.context
        limit = self.max_context_length
        if limit is not None and len(context) >= limit:
            # Keep the newest (limit - 1) tokens so appending `token` yields
            # exactly `limit`. Use an explicit start index: a negative slice
            # would degrade to `[-0:]` (the whole list) when limit == 1.
            keep = max(limit - 1, 0)
            context = context[len(context) - keep :]
        return StatefulTokenizedLM(
            self.model,
            context + [token],
            n_calls=self._n_calls,
            max_context_length=limit,
        )

    async def logp_next(self):
        """Computes log probabilities for the next token given the current context.

        Returns:
            (torch.Tensor): Log probabilities for next tokens
        """
        self._n_calls += 1
        return await self.model.next_token_logprobs(self.context)

    def __repr__(self):
        return colors.purple % (
            "|".join([escape(self.model.byte_vocab[x]) for x in self.context])
        )

initial(model, initial_context=None, max_context_length=None) classmethod

Creates an initial state for the language model.

Parameters:

Name Type Description Default
model AsyncLM

The language model to use

required
initial_context list

Initial context of token IDs. Defaults to [tokenizer.bos_token_id]

None
max_context_length int

Maximum context length to maintain

None

Returns:

Type Description
StatefulTokenizedLM

A new instance with initial state

Source code in genlm/bytes/byte_lm/lm_state.py
@classmethod
def initial(cls, model, initial_context=None, max_context_length=None):
    """Construct a fresh state, seeding the context with BOS when none is given.

    Args:
        model (genlm.backend.AsyncLM): The language model to use
        initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id]
        max_context_length (int, optional): Maximum context length to maintain

    Returns:
        (StatefulTokenizedLM): A new instance with initial state
    """
    context = (
        [model.tokenizer.bos_token_id] if initial_context is None else initial_context
    )
    return cls(model, context, max_context_length=max_context_length)

__lshift__(token)

Adds a new token to the context and returns a new state.

Parameters:

Name Type Description Default
token int

Token ID to add to context

required

Returns:

Type Description
StatefulTokenizedLM

New state with updated context

Source code in genlm/bytes/byte_lm/lm_state.py
def __lshift__(self, token):
    """Adds a new token to the context and returns a new state.

    When `max_context_length` is set, the oldest tokens are dropped so the new
    context holds at most that many tokens. The limit is carried over to the
    new state so it keeps applying on later shifts. The receiver is not modified.

    Args:
        token (int): Token ID to add to context

    Returns:
        (StatefulTokenizedLM): New state with updated context
    """
    assert isinstance(token, int)
    context = self.context
    limit = self.max_context_length
    if limit is not None and len(context) >= limit:
        # Keep the newest (limit - 1) tokens so appending `token` yields exactly
        # `limit`. Use an explicit start index: a negative slice would degrade
        # to `[-0:]` (the whole list) when limit == 1.
        keep = max(limit - 1, 0)
        context = context[len(context) - keep :]
    return StatefulTokenizedLM(
        self.model,
        context + [token],
        n_calls=self._n_calls,
        max_context_length=limit,
    )

logp_next() async

Computes log probabilities for the next token given the current context.

Returns:

Type Description
Tensor

Log probabilities for next tokens

Source code in genlm/bytes/byte_lm/lm_state.py
async def logp_next(self):
    """Query the model for next-token log-probabilities and count the call.

    Returns:
        (torch.Tensor): Log probabilities for next tokens
    """
    self._n_calls += 1
    query = self.model.next_token_logprobs
    return await query(self.context)