
genlm.bytes

ByteBeamState

Bases: StatefulByteLM

Represents the state of the beam during byte-level language modeling.

Tracks multiple candidate states and their probabilities, pruning low-probability candidates.

Parameters:

    states (list[LazyTrieState]): List of candidate states to track. Required.
    params (BeamParams): Parameters controlling beam search behavior. Required.
Source code in genlm/bytes/byte_lm/beam.py
class ByteBeamState(StatefulByteLM):
    """Represents the state of the beam during byte-level language modeling.

    Tracks multiple candidate states and their probabilities, pruning low-probability
    candidates.

    Args:
        states (list[LazyTrieState]): List of candidate states to track
        params (BeamParams): Parameters controlling beam search behavior
    """

    def __init__(self, states, params):
        self.states = sorted(states, key=lambda b: -b.weight)
        self.params = params

    @classmethod
    async def initial(cls, llm, params, trie_opts=None):
        """Creates initial beam state.

        Args:
            llm (StatefulTokenizedLM): Token-level language model to use.
            params (BeamParams): Beam search parameters.
            trie_opts (dict, optional): Additional keyword arguments passed to
                AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}.

        Returns:
            (ByteBeamState): Initial beam state.
        """
        state = LazyTrieState.initial(
            llm,
            AsyncTokenByteTrie.from_vocab(
                get_byte_vocab(llm.tokenizer), **(trie_opts or {})
            ),
        )
        return cls([await state.materialize()], params)

    def __iter__(self):
        return iter(self.states)

    def __len__(self):
        return len(self.states)

    @cached_property
    def logZ(self):
        """Estimate of the partition function (sum of weights) for current beam.
        This is the estimate of the prefix probability of the bytes consumed so far.
        """
        return logsumexp([state.weight for state in self])

    async def __lshift__(self, a):
        """Advances the beam state with a new byte.

        Args:
            a (int): Byte to add to states.

        Returns:
            (ByteBeamState): New beam state after processing the byte.
        """
        new_states = []
        for state in self:
            if new_state := state << a:
                new_states.append(new_state)

        logZ = logsumexp([s.weight for s in new_states])
        for state in await self.extend(logZ):
            if new_state := state << a:
                new_states.append(new_state)

        new_state = ByteBeamState(new_states, self.params)

        if self.params.verbose:
            print()
            print(new_state)

        return new_state

    async def logp_next(self):
        """Computes log probabilities for the next byte across all beam candidates.

        Returns:
            (LazyByteProbs): Log probabilities for next possible bytes.
        """
        assert len(self) > 0, "Beam is empty"

        logqs = []
        for state in self:
            logqs.append(state.logp_next.ps + state.weight)

        for state in await self.extend(self.logZ):
            logqs.append(state.logp_next.ps + state.weight)

        logqs = np.stack(logqs, axis=0)  # shape: (num_states, 257)
        logqs[: len(self), -1] = -np.inf  # mask EOT for the original (non-extended) states
        logps = scipy_logsumexp(logqs, axis=0)

        return LazyByteProbs(logps - logsumexp(logps))

    async def extend(self, logZ):
        """Attempts to advance each candidate in the beam by a token (EOT).

        For each candidate with EOT available, this ends the current token and
        starts a new one in preparation for the next byte.

        Args:
            logZ (float): Current estimate of the partition function, used for pruning

        Returns:
            (list[LazyTrieState]): New candidate states after extension
        """
        extends = []
        for state in self:
            if new_state := state.extend():
                logZ = np.logaddexp(logZ, new_state.weight)
                extends.append(new_state)

        coros = []
        for state in extends:
            if state.weight - logZ > self.params.log_prune_threshold:
                coros.append(state.materialize())

        return await asyncio.gather(*coros)

    def prune(self):
        """Prunes beam to maintain beam width and probability threshold.

        Returns:
            (ByteBeamState): New state with pruned candidates.
        """
        new_states = [
            state
            for state in self
            if state.weight - self.logZ > self.params.log_prune_threshold
        ][: self.params.K]
        return ByteBeamState(new_states, self.params)

    def __repr__(self):
        desc = colors.bold % f"Z: {self.logZ}\n" + colors.bold % "Candidates:\n"
        for state in self:
            P = np.exp(state.weight - self.logZ)
            color = colors.green if P > self.params.prune_threshold else colors.red
            desc += f"({color % f'{P:.4f}'}) {repr(state)}\n"
        return desc

    async def cleanup(self):
        """Cleans up resources used by the candidates."""
        await asyncio.gather(*[state.cleanup() for state in self])
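A minimal usage sketch (not from the library's documentation; the top-level import path is an assumption, and llm is assumed to be a genlm.backend AsyncLM-style model with a .tokenizer attribute):

import asyncio
from genlm.bytes import ByteBeamState, BeamParams  # assumed import path

async def main(llm):  # llm: an AsyncLM-style token model (assumed)
    state = await ByteBeamState.initial(llm, BeamParams(K=5, prune_threshold=0.01))
    for b in b"Hello":
        state = await (state << b)   # advance the beam by one byte
        state = state.prune()        # enforce beam width K and the prune threshold
    logps = await state.logp_next()  # LazyByteProbs over the next byte (257 slots)
    await state.cleanup()
    return logps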

initial(llm, params, trie_opts=None) async classmethod

Creates initial beam state.

Parameters:

    llm (StatefulTokenizedLM): Token-level language model to use. Required.
    params (BeamParams): Beam search parameters. Required.
    trie_opts (dict, optional): Additional keyword arguments passed to AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}. Default: None.

Returns:

    (ByteBeamState): Initial beam state.

Source code in genlm/bytes/byte_lm/beam.py
@classmethod
async def initial(cls, llm, params, trie_opts=None):
    """Creates initial beam state.

    Args:
        llm (StatefulTokenizedLM): Token-level language model to use.
        params (BeamParams): Beam search parameters.
        trie_opts (dict, optional): Additional keyword arguments passed to
            AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}.

    Returns:
        (ByteBeamState): Initial beam state.
    """
    state = LazyTrieState.initial(
        llm,
        AsyncTokenByteTrie.from_vocab(
            get_byte_vocab(llm.tokenizer), **(trie_opts or {})
        ),
    )
    return cls([await state.materialize()], params)

logZ cached property

Estimate of the partition function (sum of weights) for the current beam. This is an estimate of the prefix probability of the bytes consumed so far.
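Concretely, logZ is a log-sum-exp over the candidate weights. A small numeric sketch with hypothetical values:

import numpy as np
weights = np.array([-1.2, -2.3, -4.0])  # hypothetical candidate log weights
logZ = np.logaddexp.reduce(weights)     # log(e**-1.2 + e**-2.3 + e**-4.0) ≈ -0.87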

__lshift__(a) async

Advances the beam state with a new byte.

Parameters:

    a (int): Byte to add to states. Required.

Returns:

    (ByteBeamState): New beam state after processing the byte.

Source code in genlm/bytes/byte_lm/beam.py
async def __lshift__(self, a):
    """Advances the beam state with a new byte.

    Args:
        a (int): Byte to add to states.

    Returns:
        (ByteBeamState): New beam state after processing the byte.
    """
    new_states = []
    for state in self:
        if new_state := state << a:
            new_states.append(new_state)

    logZ = logsumexp([s.weight for s in new_states])
    for state in await self.extend(logZ):
        if new_state := state << a:
            new_states.append(new_state)

    new_state = ByteBeamState(new_states, self.params)

    if self.params.verbose:
        print()
        print(new_state)

    return new_state

logp_next() async

Computes log probabilities for the next byte across all beam candidates.

Returns:

    (LazyByteProbs): Log probabilities for next possible bytes.

Source code in genlm/bytes/byte_lm/beam.py
async def logp_next(self):
    """Computes log probabilities for the next byte across all beam candidates.

    Returns:
        (LazyByteProbs): Log probabilities for next possible bytes.
    """
    assert len(self) > 0, "Beam is empty"

    logqs = []
    for state in self:
        logqs.append(state.logp_next.ps + state.weight)

    for state in await self.extend(self.logZ):
        logqs.append(state.logp_next.ps + state.weight)

    logqs = np.stack(logqs, axis=0)  # shape: (num_states, 257)
    logqs[: len(self), -1] = -np.inf  # mask EOT for the original (non-extended) states
    logps = scipy_logsumexp(logqs, axis=0)

    return LazyByteProbs(logps - logsumexp(logps))
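The stacked arrays above use the 257-slot layout shared across this page: slots 0-255 are byte values and slot 256 is the end-of-token (EOT) marker (see LazyTrieState.logp_next below). A hedged illustration with made-up values:

import numpy as np
logps = np.full(257, -np.inf)  # slots 0..255: byte values; slot 256: EOT
logps[ord("a")] = -0.7         # hypothetical log P(next byte = b"a")
logps[256] = -1.4              # hypothetical log P(end of token)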

extend(logZ) async

Attempts to advance each candidate in the beam by a token (EOT).

For each candidate with EOT available, this ends the current token and starts a new one in preparation for the next byte.

Parameters:

    logZ (float): Current estimate of the partition function, used for pruning. Required.

Returns:

    (list[LazyTrieState]): New candidate states after extension.

Source code in genlm/bytes/byte_lm/beam.py
async def extend(self, logZ):
    """Attempts to advance each candidate in the beam by a token (EOT).

    For each candidate with EOT available, this ends the current token and
    starts a new one in preparation for the next byte.

    Args:
        logZ (float): Current estimate of the partition function, used for pruning

    Returns:
        (list[LazyTrieState]): New candidate states after extension
    """
    extends = []
    for state in self:
        if new_state := state.extend():
            logZ = np.logaddexp(logZ, new_state.weight)
            extends.append(new_state)

    coros = []
    for state in extends:
        if state.weight - logZ > self.params.log_prune_threshold:
            coros.append(state.materialize())

    return await asyncio.gather(*coros)

prune()

Prunes beam to maintain beam width and probability threshold.

Returns:

    (ByteBeamState): New state with pruned candidates.

Source code in genlm/bytes/byte_lm/beam.py
def prune(self):
    """Prunes beam to maintain beam width and probability threshold.

    Returns:
        (ByteBeamState): New state with pruned candidates.
    """
    new_states = [
        state
        for state in self
        if state.weight - self.logZ > self.params.log_prune_threshold
    ][: self.params.K]
    return ByteBeamState(new_states, self.params)

cleanup() async

Cleans up resources used by the candidates.

Source code in genlm/bytes/byte_lm/beam.py
async def cleanup(self):
    """Cleans up resources used by the candidates."""
    await asyncio.gather(*[state.cleanup() for state in self])

LazyTrieState

A lazy-evaluated state of a TokenByteTrie traversal.

This class maintains the state of a language model while traversing a trie structure, lazily evaluating probabilities and maintaining the weight of the current path through the trie for beam search.

Parameters:

    lm_state (StatefulTokenizedLM): Current language model state. Required.
    trie (TokenByteTrie): Trie structure mapping tokens to byte sequences. Required.
    node (int): Current node in the trie. Required.
    weight (float): Cumulative log probability of the path to this node. Required.
    mass (numpy.ndarray, optional): Masses for each node in the trie for the current state. Default: None.
Source code in genlm/bytes/byte_lm/trie_state.py
class LazyTrieState:
    """A lazy-evaluated state of a TokenByteTrie traversal.

    This class maintains the state of a language model while traversing a trie structure,
    lazily evaluating probabilities and maintaining the weight of the current path through the trie
    for beam search.

    Args:
        lm_state (StatefulTokenizedLM): Current language model state
        trie (TokenByteTrie): Trie structure mapping tokens to byte sequences
        node (int): Current node in the trie
        weight (float): Cumulative log probability of the path to this node
        mass (numpy.ndarray, optional): Masses for each node in the trie for the current state
    """

    def __init__(self, lm_state, trie, node, weight, mass=None):
        self.lm_state = lm_state
        self.trie = trie
        self.node = node
        self.weight = weight
        self._mass = mass
        self._extend = None
        self.root = self.trie.trie.root
        self.children = self.trie.trie.children

    @classmethod
    def initial(cls, lm, trie):
        """Creates an initial trie state.

        Args:
            lm (genlm.backend.AsyncLM): Language model to use
            trie (TokenByteTrie): TokenByteTrie structure for byte-to-token mapping

        Returns:
            (LazyTrieState): Initial state at root of trie with weight 0.0
        """
        return cls(
            trie=trie,
            node=trie.trie.root,
            lm_state=StatefulTokenizedLM.initial(lm),
            weight=0.0,
        )

    @property
    def partial(self):
        """Returns the byte sequence corresponding to the current node in the trie."""
        return self.trie.trie.node2prefix[self.node]

    @property
    def mass(self):
        """Returns the log mass for each node in the trie.

        The mass at a node corresponds to the sum of the probabilities of all
        tokens which share the prefix (`self.partial`) represented by that node.

        Raises:
            ValueError: If state hasn't been materialized yet
        """
        if self._mass is None:
            raise ValueError("State is not yet materialized.")
        return self._mass

    def actions(self):
        """Returns possible byte transitions from current node."""
        return self.children[self.node]

    def get_EOT(self):
        """Returns the end-of-token node if available from current position in the trie."""
        return self.children[self.node].get(self.trie.trie.eot_token)

    def __lshift__(self, b):
        """Transitions to a new state by consuming a byte.

        Args:
            b (int): Byte to consume

        Returns:
            (LazyTrieState|None): New state after consuming byte, or None if transition invalid
        """
        if node := self.children[self.node].get(b):
            mass = self.mass
            return LazyTrieState(
                lm_state=self.lm_state,
                trie=self.trie,
                mass=mass,
                node=node,
                weight=self.weight + mass[node] - mass[self.node],
            )

    def extend(self):
        """Extends current state by consuming an end-of-token if possible.

        Returns:
            (LazyTrieState|None): New state after consuming EOT, or None if not possible
        """
        if self._extend is None:
            if eot_node := self.get_EOT():
                mass = self.mass
                self._extend = LazyTrieState(
                    lm_state=self.lm_state
                    << int(self.trie.trie.leaf2token_id[eot_node]),
                    trie=self.trie,
                    node=self.root,
                    weight=self.weight + mass[eot_node] - mass[self.node],
                )
        return self._extend

    @cached_property
    def logp_next(self):
        """Computes log probabilities for next possible transitions.

        Returns:
            (LazyByteProbs): Lazy log probability distribution over possible next bytes
        """
        logps = np.full(257, -np.inf)  # 256 byte values plus one EOT slot
        mass = self.mass
        logZ = mass[self.node]
        for byte, node in self.actions().items():
            logps[byte if byte is not None else 256] = mass[node] - logZ
        return LazyByteProbs(logps)

    async def materialize(self):
        """Materializes the masses for each node in the trie for the current state.

        This makes a call to the language model and the underlying trie.

        Returns:
            (LazyTrieState): Self with materialized masses
        """
        if self._mass is None:
            logp_next = await self.lm_state.logp_next()
            mass = await self.trie.weight_sum(torch.exp(logp_next))
            log_mass = torch.log(mass)  # convert masses to log space
            self._mass = log_mass.cpu().numpy()
        return self

    def __repr__(self):
        context = colors.green % ("|" + escape(bytes(self.partial)))
        return f"{self.weight:.2f}: {self.lm_state}" + context

    async def cleanup(self):
        """Cleans up resources used by the trie."""
        await self.trie.cleanup()
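A hedged sketch of stepping a LazyTrieState through the bytes of one token; llm (a genlm.backend AsyncLM) and async_trie (an AsyncTokenByteTrie) are assumed to exist:

async def consume_token(llm, async_trie):
    state = await LazyTrieState.initial(llm, async_trie).materialize()
    for b in b"the":
        state = state << b            # returns None if no token has this byte prefix
        assert state is not None
    if next_state := state.extend():  # consume EOT, committing the token
        state = await next_state.materialize()
    return state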

initial(lm, trie) classmethod

Creates an initial trie state.

Parameters:

    lm (genlm.backend.AsyncLM): Language model to use. Required.
    trie (TokenByteTrie): TokenByteTrie structure for byte-to-token mapping. Required.

Returns:

    (LazyTrieState): Initial state at root of trie with weight 0.0.

Source code in genlm/bytes/byte_lm/trie_state.py
@classmethod
def initial(cls, lm, trie):
    """Creates an initial trie state.

    Args:
        lm (genlm.backend.AsyncLM): Language model to use
        trie (TokenByteTrie): TokenByteTrie structure for byte-to-token mapping

    Returns:
        (LazyTrieState): Initial state at root of trie with weight 0.0
    """
    return cls(
        trie=trie,
        node=trie.trie.root,
        lm_state=StatefulTokenizedLM.initial(lm),
        weight=0.0,
    )

partial property

Returns the byte sequence corresponding to the current node in the trie.

mass property

Returns the log mass for each node in the trie.

The mass at a node corresponds to the sum of the probabilities of all tokens which share the prefix (self.partial) represented by that node.

Raises:

    (ValueError): If the state hasn't been materialized yet.
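For intuition, a toy example (not from the library): with token probabilities {b"a": 0.5, b"ab": 0.25, b"b": 0.25}, the node for the prefix b"a" covers both b"a" and b"ab":

import numpy as np
mass_prefix_a = np.log(0.5 + 0.25)  # ≈ -0.29, log mass of the b"a" node
mass_root = np.log(1.0)             # 0.0: every token shares the empty prefix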

actions()

Returns possible byte transitions from current node.

Source code in genlm/bytes/byte_lm/trie_state.py
def actions(self):
    """Returns possible byte transitions from current node."""
    return self.children[self.node]

get_EOT()

Returns the end-of-token node if available from current position in the trie.

Source code in genlm/bytes/byte_lm/trie_state.py
def get_EOT(self):
    """Returns the end-of-token node if available from current position in the trie."""
    return self.children[self.node].get(self.trie.trie.eot_token)

__lshift__(b)

Transitions to a new state by consuming a byte.

Parameters:

    b (int): Byte to consume. Required.

Returns:

    (LazyTrieState | None): New state after consuming the byte, or None if the transition is invalid.

Source code in genlm/bytes/byte_lm/trie_state.py
def __lshift__(self, b):
    """Transitions to a new state by consuming a byte.

    Args:
        b (int): Byte to consume

    Returns:
        (LazyTrieState|None): New state after consuming byte, or None if transition invalid
    """
    if node := self.children[self.node].get(b):
        mass = self.mass
        return LazyTrieState(
            lm_state=self.lm_state,
            trie=self.trie,
            mass=mass,
            node=node,
            weight=self.weight + mass[node] - mass[self.node],
        )

extend()

Extends current state by consuming an end-of-token if possible.

Returns:

    (LazyTrieState | None): New state after consuming EOT, or None if not possible.

Source code in genlm/bytes/byte_lm/trie_state.py
def extend(self):
    """Extends current state by consuming an end-of-token if possible.

    Returns:
        (LazyTrieState|None): New state after consuming EOT, or None if not possible
    """
    if self._extend is None:
        if eot_node := self.get_EOT():
            mass = self.mass
            self._extend = LazyTrieState(
                lm_state=self.lm_state
                << int(self.trie.trie.leaf2token_id[eot_node]),
                trie=self.trie,
                node=self.root,
                weight=self.weight + mass[eot_node] - mass[self.node],
            )
    return self._extend

logp_next cached property

Computes log probabilities for next possible transitions.

Returns:

    (LazyByteProbs): Lazy log probability distribution over possible next bytes.

materialize() async

Materializes the masses for each node in the trie for the current state.

This makes a call to the language model and the underlying trie.

Returns:

    (LazyTrieState): Self with materialized masses.

Source code in genlm/bytes/byte_lm/trie_state.py
async def materialize(self):
    """Materializes the masses for each node in the trie for the current state.

    This makes a call to the language model and the underlying trie.

    Returns:
        (LazyTrieState): Self with materialized masses
    """
    if self._mass is None:
        logp_next = await self.lm_state.logp_next()
        mass = await self.trie.weight_sum(torch.exp(logp_next))
        log_mass = torch.log(mass)  # convert masses to log space
        self._mass = log_mass.cpu().numpy()
    return self

cleanup() async

Cleans up resources used by the trie.

Source code in genlm/bytes/byte_lm/trie_state.py
async def cleanup(self):
    """Cleans up resources used by the trie."""
    await self.trie.cleanup()

StatefulTokenizedLM

A stateful tokenized language model that maintains context and generates next token logprobs.

Parameters:

    model (genlm.backend.AsyncLM): The underlying language model. Required.
    context (list): List of token IDs representing the current context. Required.
    n_calls (int): Number of times the model has been called. Default: 0.
    max_context_length (int, optional): Maximum length of context to maintain. Default: None.
Source code in genlm/bytes/byte_lm/lm_state.py
class StatefulTokenizedLM:
    """A stateful tokenized language model that maintains context and generates next token logprobs.

    Args:
        model (genlm.backend.AsyncLM): The underlying language model
        context (list): List of token IDs representing the current context
        n_calls (int): Number of times the model has been called
        max_context_length (int, optional): Maximum length of context to maintain
    """

    def __init__(self, model, context, n_calls=0, max_context_length=None):
        self.model = model
        self.context = context
        self._n_calls = n_calls
        self.max_context_length = max_context_length

    @classmethod
    def initial(cls, model, initial_context=None, max_context_length=None):
        """Creates an initial state for the language model.

        Args:
            model (genlm.backend.AsyncLM): The language model to use
            initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id]
            max_context_length (int, optional): Maximum context length to maintain

        Returns:
            (StatefulTokenizedLM): A new instance with initial state
        """
        if initial_context is None:
            initial_context = [model.tokenizer.bos_token_id]
        return cls(model, initial_context, max_context_length=max_context_length)

    def __lshift__(self, token):
        """Adds a new token to the context and returns a new state.

        Args:
            token (int): Token ID to add to context

        Returns:
            (StatefulTokenizedLM): New state with updated context
        """
        assert isinstance(token, int)
        if (
            self.max_context_length is not None
            and len(self.context) >= self.max_context_length
        ):
            self.context = self.context[-(self.max_context_length - 1) :]
        return StatefulTokenizedLM(
            self.model, self.context + [token], n_calls=self._n_calls
        )

    async def logp_next(self):
        """Computes log probabilities for the next token given the current context.

        Returns:
            (torch.Tensor): Log probabilities for next tokens
        """
        self._n_calls += 1
        return await self.model.next_token_logprobs(self.context)

    def __repr__(self):
        return colors.purple % (
            "|".join([escape(self.model.byte_vocab[x]) for x in self.context])
        )
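A hedged sketch of the interface; llm is assumed to be a genlm.backend AsyncLM with a tokenizer, and the token id is made up:

async def demo(llm):
    state = StatefulTokenizedLM.initial(llm)  # context starts at [bos_token_id]
    state = state << 464                      # append a hypothetical token id
    return await state.logp_next()            # torch.Tensor over the token vocabulary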

initial(model, initial_context=None, max_context_length=None) classmethod

Creates an initial state for the language model.

Parameters:

    model (genlm.backend.AsyncLM): The language model to use. Required.
    initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id].
    max_context_length (int, optional): Maximum context length to maintain. Default: None.

Returns:

    (StatefulTokenizedLM): A new instance with initial state.

Source code in genlm/bytes/byte_lm/lm_state.py
@classmethod
def initial(cls, model, initial_context=None, max_context_length=None):
    """Creates an initial state for the language model.

    Args:
        model (genlm.backend.AsyncLM): The language model to use
        initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id]
        max_context_length (int, optional): Maximum context length to maintain

    Returns:
        (StatefulTokenizedLM): A new instance with initial state
    """
    if initial_context is None:
        initial_context = [model.tokenizer.bos_token_id]
    return cls(model, initial_context, max_context_length=max_context_length)

__lshift__(token)

Adds a new token to the context and returns a new state.

Parameters:

    token (int): Token ID to add to the context. Required.

Returns:

    (StatefulTokenizedLM): New state with updated context.

Source code in genlm/bytes/byte_lm/lm_state.py
def __lshift__(self, token):
    """Adds a new token to the context and returns a new state.

    Args:
        token (int): Token ID to add to context

    Returns:
        (StatefulTokenizedLM): New state with updated context
    """
    assert isinstance(token, int)
    if (
        self.max_context_length is not None
        and len(self.context) >= self.max_context_length
    ):
        self.context = self.context[-(self.max_context_length - 1) :]
    return StatefulTokenizedLM(
        self.model, self.context + [token], n_calls=self._n_calls
    )

logp_next() async

Computes log probabilities for the next token given the current context.

Returns:

    (torch.Tensor): Log probabilities for next tokens.

Source code in genlm/bytes/byte_lm/lm_state.py
async def logp_next(self):
    """Computes log probabilities for the next token given the current context.

    Returns:
        (torch.Tensor): Log probabilities for next tokens
    """
    self._n_calls += 1
    return await self.model.next_token_logprobs(self.context)

BeamParams dataclass

Parameters for byte-level beam summing algorithm.

Parameters:

    K (int): Beam width; the maximum number of candidates to maintain. Required.
    prune_threshold (float, optional): Probability threshold for pruning candidates. Candidates with probability below this are removed. Default: 0.0.
    verbose (bool, optional): Whether to print the beam state at each step. Default: False.
Source code in genlm/bytes/byte_lm/beam.py
@dataclass
class BeamParams:
    """Parameters for byte-level beam summing algorithm.

    Args:
        K (int): Beam width - maximum number of candidates to maintain.
        prune_threshold (float, optional): Probability threshold for pruning candidates.
            Candidates with probability below this are removed. Defaults to 0.0
        verbose (bool, optional): Whether to print the beam state at each step. Defaults to False
    """

    K: int
    prune_threshold: float = 0.0
    verbose: bool = False

    def __post_init__(self):
        if self.prune_threshold < 0:
            raise ValueError(
                f"prune_threshold must be non-negative, got {self.prune_threshold}"
            )
        self.log_prune_threshold = (
            np.log(self.prune_threshold) if self.prune_threshold > 0 else -np.inf
        )
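A small sketch of how prune_threshold maps to log space, following __post_init__ above (the top-level import path is an assumption):

import numpy as np
from genlm.bytes import BeamParams  # assumed import path

params = BeamParams(K=8, prune_threshold=0.01)
# Candidates with weight - logZ <= log(0.01) ≈ -4.61 are pruned.
assert np.isclose(params.log_prune_threshold, np.log(0.01))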

TokenByteTrie

A trie data structure for efficient token-to-byte mapping.

Source code in genlm/bytes/trie.py
class TokenByteTrie:
    """A trie data structure for efficient token-to-byte mapping."""

    def __init__(
        self, decode, device=None, atomic_tokens=None, eot_token=None, max_batch_size=64
    ):
        """Initialize a `TokenByteTrie`.

        Args:
            decode (list[bytes]): List representing the token vocabulary.
            device (str, optional): Device to use for weight sum and max computations ('cpu' or 'cuda').
            atomic_tokens (list[bytes], optional): List of tokens that should be treated as atomic units rather than being split into bytes.
            eot_token (bytes|None, optional): Marker used for the end-of-token transition. Defaults to None, in which case EOT edges are keyed by None in the trie.
            max_batch_size (int, optional): Maximum batch size for weight sum sparse matrix multiplication.
        """
        self.decode = decode
        self.max_batch_size = max_batch_size

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device not in ["cpu", "cuda"]:
            raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

        self.eot_token = eot_token
        self._build_trie(atomic_tokens or [])
        self._renumber()
        self._build_node2prefix()
        self._build_reachability_matrix()
        self.token_ids = torch.tensor(
            self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
        )

    def _build_trie(self, atomic_tokens):
        """Builds a trie data structure from the vocabulary.

        Returns:
            (dict): A dictionary where keys are token IDs and values are lists of characters.
        """
        for token in atomic_tokens:
            if token not in self.decode:
                raise ValueError(f"Atomic token {token} not in vocabulary")

        self.word2leaf = {}
        self.children = [{}]  # First node is root
        self.root = 0
        self.token_id_to_leaf = []
        self.lookup = {}

        for token_id, word in enumerate(self.decode):
            if word in self.lookup:
                raise ValueError(f"Duplicate word in vocabulary: {word}")
            self.lookup[word] = token_id

            curr = self.root
            letters = [word] if word in atomic_tokens else word
            for letter in letters:
                if letter not in self.children[curr]:
                    self.children[curr][letter] = len(self.children)
                    self.children.append({})
                curr = self.children[curr][letter]

            self.children[curr][self.eot_token] = last = len(self.children)
            self.children.append({})
            assert word not in self.word2leaf
            self.word2leaf[word] = last
            self.token_id_to_leaf.append((token_id, last))

        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
        self.jump = [
            np.array(sorted(x.values()), dtype=np.int32) for x in self.children
        ]

    def _renumber(self):
        """Renumber the states of the trie so that they are named by a contiguous
        range of integers and those integers respect the topological ordering
        of the trie. This improves the efficiency of updating the trie by
        improving memory locality.
        """
        self.ordering = np.array(list(self._order(self.root)), np.int32)
        ordering = {}
        for i, x in enumerate(self._order_full(self.root)):
            ordering[x] = i
        self._rename(f=lambda x: ordering[x])

    def _order(self, node):
        """Generate a topological ordering of nodes beneath the given node.

        Args:
            node (int): Starting node index

        Yields:
            int: Node indices in topological order
        """
        for a in self.children[node]:
            if a is not None:
                yield from self._order(self.children[node][a])
        yield node

    def _order_full(self, node):
        """Generate a complete topological ordering including all child nodes.

        Args:
            node (int): Starting node index

        Yields:
            (int): Node indices in complete topological order
        """
        for a in self.children[node]:
            yield from self._order_full(self.children[node][a])
        yield node

    def _rename(self, f):
        """Rename all node indices in the trie using the provided mapping function.

        Args:
            f (callable): Function that maps old node indices to new node indices
        """
        N = len(self.children)

        new_children = [{} for _ in range(N)]
        nodes = range(N)

        for x in nodes:
            for letter, y in self.children[x].items():
                new_children[f(x)][letter] = f(y)

        self.root = f(self.root)
        self.children = new_children
        self.word2leaf = {w: f(x) for w, x in self.word2leaf.items()}
        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))

        self.token_id_to_leaf = np.array(
            [(i, f(x)) for i, x in self.token_id_to_leaf], dtype=np.int32
        )
        self.leaf2token_id = dict(
            zip(self.token_id_to_leaf[:, 1], self.token_id_to_leaf[:, 0])
        )

        self.ordering = np.array([f(x) for x in self.ordering])
        self.jump = [np.array(sorted(x.values()), dtype=np.int32) for x in new_children]

    def _build_node2prefix(self):
        """Builds a mapping from each node to its prefix.

        Returns:
            (dict): A dictionary where keys are node IDs and values are lists of characters.
        """
        node2prefix = {self.root: []}
        for x in reversed(range(len(self.children))):
            for letter, y in self.children[x].items():
                if letter is None:
                    node2prefix[y] = node2prefix[x]
                elif isinstance(letter, bytes):
                    node2prefix[y] = node2prefix[x] + list(letter)
                else:
                    node2prefix[y] = node2prefix[x] + [letter]

        self.node2prefix = node2prefix

    def _build_parent_map(self):
        """Builds a mapping from each node to its parent node in the trie.

        Returns:
            (dict): A dictionary where keys are child nodes and values are their parent nodes.
        """
        parent = {}
        for node in range(len(self.children)):
            for child in self.jump[node]:
                parent[child] = node
        return parent

    def _build_reachability_matrix(self):
        """Constructs a sparse reachability matrix for efficient weight propagation.

        The matrix M is constructed such that M[i,j] = 1 if node j is either:
        - The leaf node i itself (self-connection)
        - An ancestor of leaf node i in the trie
        """
        leaf_indices = self.token_id_to_leaf[:, 1]
        parent = self._build_parent_map()

        rows, cols = [], []
        for i, node in enumerate(leaf_indices):
            # self connections
            rows.append(i)
            cols.append(node)

            current = node
            while current in parent:  # Walk up to root
                ancestor = parent[current]
                rows.append(i)
                cols.append(ancestor)
                current = ancestor

        self.src_indices = torch.tensor(rows, dtype=torch.long, device=self.device)
        self.dst_indices = torch.tensor(cols, dtype=torch.long, device=self.device)

        indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device)
        values = torch.ones(len(rows), device=self.device)

        self.M = torch.sparse_coo_tensor(
            indices, values, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

    def _preprocess_ws(self, batch_ws):
        """Preprocess weight sums for batch processing.

        Args:
            batch_ws (list|np.ndarray|torch.Tensor): List of weight sum tensors or lists of weight sums.

        Returns:
            (torch.Tensor): Stacked weight sum tensor.
        """
        processed_batch_ws = []
        for ws in batch_ws:
            if not isinstance(ws, torch.Tensor):
                ws = torch.tensor(ws, device=self.device, dtype=torch.float32)
            elif ws.device != self.device or ws.dtype != torch.float32:
                ws = ws.to(device=self.device, dtype=torch.float32)
            assert ws.shape[0] == len(self.decode), [ws.shape[0], len(self.decode)]
            processed_batch_ws.append(ws)
        return torch.stack(processed_batch_ws)

    def weight_sum(self, ws):
        """Computes the sum of weights of all leaf nodes (tokens) that are descendants of each node in the trie.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (torch.Tensor): Summed weights for each node in the trie, shape (num_nodes,).
        """
        return self.batch_weight_sum(self._preprocess_ws([ws]))[0]

    def batch_weight_sum(self, ws):
        """Batch version of `weight_sum`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (torch.Tensor): Summed weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)
        batch_size = ws.shape[0]
        all_masses = []
        # If you are getting illegal memory access errors here,
        # try reducing the max_batch_size.
        for i in range(0, batch_size, self.max_batch_size):
            batch_ws = ws[i : i + self.max_batch_size]
            masses = torch.sparse.mm(batch_ws[:, self.token_ids], self.M)
            all_masses.append(masses)
        return torch.cat(all_masses, dim=0)

    def weight_max(self, ws):
        """Computes the maximum weight of all descendant leaf nodes (tokens) for each node in the trie.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (torch.Tensor): Maximum weights for each node in the trie, shape (num_nodes,).
        """
        return self.batch_weight_max(self._preprocess_ws([ws]))[0]

    def batch_weight_max(self, ws):
        """Batch version of `weight_max`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (torch.Tensor): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)

        # Get leaf weights
        leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
        batch_size = leaf_weights.shape[0]

        # Use scatter_reduce to propagate maximum values in parallel
        result = torch.zeros((batch_size, len(self.children)), device=self.device)
        result.scatter_reduce_(
            dim=1,
            index=self.dst_indices.expand(batch_size, -1),
            src=leaf_weights[:, self.src_indices],
            reduce="amax",
            include_self=False,
        )

        return result

    def visualize(self, ws=None):
        """Visualize the trie structure using Graphviz.

        Args:
            ws (np.ndarray|None): Optional weight vector to display at each node. Should be of length `len(self.children)`.

        Returns:
            (graphviz.Digraph): The generated graph object
        """
        try:
            import graphviz
        except ImportError:  # pragma: no cover
            raise ImportError(
                "Please install graphviz: pip install graphviz"
            )  # pragma: no cover

        if ws is not None and len(ws) != len(self.children):
            raise ValueError(
                f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
            )

        dot = graphviz.Digraph(comment="Token Character Trie")
        dot.attr(rankdir="LR")

        # Create a subgraph for the legend
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", fontsize="10")
            legend.attr("node", fontsize="7", width="0.1", height="0.1")

            # Example internal node
            legend.node(
                "legend_internal",
                "Internal Node ID\n'Prefix'\nWeight (if provided)",
                shape="circle",
            )

            # Example leaf node
            legend.node("legend_leaf", "Complete Token", shape="doublecircle")

            legend.edge(
                "legend_internal",
                "legend_leaf",
                label="Token item",
                fontsize="10",
            )

            # Align legend horizontally
            legend.attr(rankdir="TB")
            legend.attr(rank="same")

        # Add the main trie nodes and edges
        for node_id in range(len(self.children)):
            prefix = self.node2prefix[node_id]

            if ws is not None:
                label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
            else:
                label = f"{node_id}\n'{prefix}'"

            # Color nodes based on mass if provided
            if ws is not None:
                max_ws = ws.max()
                if max_ws > 0:
                    intensity = int(255 * (1 - ws[node_id] / max_ws))
                    color = f"#{intensity:02x}{255:02x}{intensity:02x}"
                else:
                    color = "#ffffff"  # white for zero mass
            else:
                color = "#ffffff"  # default white

            if node_id in self.leaf2word:
                dot.node(
                    str(node_id),
                    label,
                    shape="doublecircle",
                    style="filled",
                    fillcolor=color,
                )
            else:
                dot.node(
                    str(node_id), label, shape="circle", style="filled", fillcolor=color
                )

        for node_id, children in enumerate(self.children):
            for char, child_id in children.items():
                if char is not None:
                    edge_label = str(char)
                else:
                    edge_label = "End-of-Token"

                dot.edge(str(node_id), str(child_id), label=edge_label)

        return dot
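A toy example of weight_sum and weight_max on a three-token vocabulary (hedged; the module path follows the "Source code in genlm/bytes/trie.py" note above):

import torch
from genlm.bytes.trie import TokenByteTrie  # assumed module path

trie = TokenByteTrie(decode=[b"a", b"ab", b"b"], device="cpu")
ws = torch.tensor([0.5, 0.25, 0.25])  # one weight per token
sums = trie.weight_sum(ws)            # node for prefix b"a" gets 0.5 + 0.25
maxes = trie.weight_max(ws)           # node for prefix b"a" gets max(0.5, 0.25)
print(float(sums[trie.root]))         # 1.0: the root covers every token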

__init__(decode, device=None, atomic_tokens=None, eot_token=None, max_batch_size=64)

Initialize a TokenByteTrie.

Parameters:

    decode (list[bytes]): List representing the token vocabulary. Required.
    device (str, optional): Device to use for weight sum and max computations ('cpu' or 'cuda'). Default: None.
    atomic_tokens (list[bytes], optional): List of tokens that should be treated as atomic units rather than being split into bytes. Default: None.
    eot_token (bytes | None, optional): Marker used for the end-of-token transition. Defaults to None, in which case EOT edges are keyed by None in the trie.
    max_batch_size (int, optional): Maximum batch size for the weight sum sparse matrix multiplication. Default: 64.
Source code in genlm/bytes/trie.py
def __init__(
    self, decode, device=None, atomic_tokens=None, eot_token=None, max_batch_size=64
):
    """Initialize a `TokenByteTrie`.

    Args:
        decode (list[bytes]): List representing the token vocabulary.
        device (str, optional): Device to use for weight sum and max computations ('cpu' or 'cuda').
        atomic_tokens (list[bytes], optional): List of tokens that should be treated as atomic units rather than being split into bytes.
        eot_token (bytes|None, optional): Marker used for the end-of-token transition. Defaults to None, in which case EOT edges are keyed by None in the trie.
        max_batch_size (int, optional): Maximum batch size for weight sum sparse matrix multiplication.
    """
    self.decode = decode
    self.max_batch_size = max_batch_size

    self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if self.device not in ["cpu", "cuda"]:
        raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

    self.eot_token = eot_token
    self._build_trie(atomic_tokens or [])
    self._renumber()
    self._build_node2prefix()
    self._build_reachability_matrix()
    self.token_ids = torch.tensor(
        self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
    )

weight_sum(ws)

Computes the sum of weights of all leaf nodes (tokens) that are descendants of each node in the trie.

Parameters:

    ws (torch.Tensor): Token weights, shape (len(self.decode),). Required.

Returns:

    (torch.Tensor): Summed weights for each node in the trie, shape (num_nodes,).

Source code in genlm/bytes/trie.py
def weight_sum(self, ws):
    """Computes the sum of weights of all leaf nodes (tokens) that are descendants of each node in the trie.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (torch.Tensor): Summed weights for each node in the trie, shape (num_nodes,).
    """
    return self.batch_weight_sum(self._preprocess_ws([ws]))[0]

batch_weight_sum(ws)

Batch version of weight_sum.

Parameters:

    ws (torch.Tensor): Batch of token weights, shape (batch_size × len(self.decode)). Required.

Returns:

    (torch.Tensor): Summed weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/bytes/trie.py
def batch_weight_sum(self, ws):
    """Batch version of `weight_sum`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (torch.Tensor): Summed weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)
    batch_size = ws.shape[0]
    all_masses = []
    # If you are getting illegal memory access errors here,
    # try reducing the max_batch_size.
    for i in range(0, batch_size, self.max_batch_size):
        batch_ws = ws[i : i + self.max_batch_size]
        masses = torch.sparse.mm(batch_ws[:, self.token_ids], self.M)
        all_masses.append(masses)
    return torch.cat(all_masses, dim=0)

weight_max(ws)

Computes the maximum weight of all descendant leaf nodes (tokens) for each node in the trie.

Parameters:

    ws (torch.Tensor): Token weights, shape (len(self.decode),). Required.

Returns:

    (torch.Tensor): Maximum weights for each node in the trie, shape (num_nodes,).

Source code in genlm/bytes/trie.py
def weight_max(self, ws):
    """Computes the maximum weight of all descendant leaf nodes (tokens) for each node in the trie.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (torch.Tensor): Maximum weights for each node in the trie, shape (num_nodes,).
    """
    return self.batch_weight_max(self._preprocess_ws([ws]))[0]

batch_weight_max(ws)

Batch version of weight_max.

Parameters:

    ws (torch.Tensor): Batch of token weights, shape (batch_size × len(self.decode)). Required.

Returns:

    (torch.Tensor): Maximum weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/bytes/trie.py
def batch_weight_max(self, ws):
    """Batch version of `weight_max`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)

    # Get leaf weights
    leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
    batch_size = leaf_weights.shape[0]

    # Use scatter_reduce to propagate maximum values in parallel
    result = torch.zeros((batch_size, len(self.children)), device=self.device)
    result.scatter_reduce_(
        dim=1,
        index=self.dst_indices.expand(batch_size, -1),
        src=leaf_weights[:, self.src_indices],
        reduce="amax",
        include_self=False,
    )

    return result

visualize(ws=None)

Visualize the trie structure using Graphviz.

Parameters:

    ws (np.ndarray | None): Optional weight vector to display at each node. Should be of length len(self.children). Default: None.

Returns:

    (graphviz.Digraph): The generated graph object.

Source code in genlm/bytes/trie.py
def visualize(self, ws=None):
    """Visualize the trie structure using Graphviz.

    Args:
        ws (np.ndarray|None): Optional weight vector to display at each node. Should be of length `len(self.children)`.

    Returns:
        (graphviz.Digraph): The generated graph object
    """
    try:
        import graphviz
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Please install graphviz: pip install graphviz"
        )  # pragma: no cover

    if ws is not None and len(ws) != len(self.children):
        raise ValueError(
            f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
        )

    dot = graphviz.Digraph(comment="Token Character Trie")
    dot.attr(rankdir="LR")

    # Create a subgraph for the legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", fontsize="10")
        legend.attr("node", fontsize="7", width="0.1", height="0.1")

        # Example internal node
        legend.node(
            "legend_internal",
            "Internal Node ID\n'Prefix'\nWeight (if provided)",
            shape="circle",
        )

        # Example leaf node
        legend.node("legend_leaf", "Complete Token", shape="doublecircle")

        legend.edge(
            "legend_internal",
            "legend_leaf",
            label="Token item",
            fontsize="10",
        )

        # Align legend horizontally
        legend.attr(rankdir="TB")
        legend.attr(rank="same")

    # Add the main trie nodes and edges
    for node_id in range(len(self.children)):
        prefix = self.node2prefix[node_id]

        if ws is not None:
            label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
        else:
            label = f"{node_id}\n'{prefix}'"

        # Color nodes based on mass if provided
        if ws is not None:
            max_ws = ws.max()
            if max_ws > 0:
                intensity = int(255 * (1 - ws[node_id] / max_ws))
                color = f"#{intensity:02x}{255:02x}{intensity:02x}"
            else:
                color = "#ffffff"  # white for zero mass
        else:
            color = "#ffffff"  # default white

        if node_id in self.leaf2word:
            dot.node(
                str(node_id),
                label,
                shape="doublecircle",
                style="filled",
                fillcolor=color,
            )
        else:
            dot.node(
                str(node_id), label, shape="circle", style="filled", fillcolor=color
            )

    for node_id, children in enumerate(self.children):
        for char, child_id in children.items():
            if char is not None:
                edge_label = str(char)
            else:
                edge_label = "End-of-Token"

            dot.edge(str(node_id), str(child_id), label=edge_label)

    return dot

AsyncTokenByteTrie

An asynchronous wrapper for TokenByteTrie implementations that provides automatic request batching.

Source code in genlm/bytes/trie.py
class AsyncTokenByteTrie:
    """An asynchronous wrapper for TokenByteTrie implementations that provides automatic request batching."""

    def __init__(self, trie):
        """Initialize an `AsyncTokenByteTrie`.

        Args:
            trie (TokenByteTrie): The underlying `TokenByteTrie` instance
        """
        self.trie = trie
        self._queue = None
        self._task = None

    @classmethod
    def from_vocab(cls, vocab, **kwargs):
        """Creates an `AsyncTokenByteTrie` from a vocabulary.

        Args:
            vocab (list): The vocabulary over which the trie will be defined.
            **kwargs (dict): Additional arguments passed to the trie constructor

        Returns:
            (AsyncTokenByteTrie): The initialized asynchronous trie instance.
        """
        trie = TokenByteTrie(decode=vocab, **kwargs)
        return cls(trie)

    def _queue_request(self, request, op):
        if not self._task or self._task.done():
            self.start()

        future = asyncio.get_running_loop().create_future()
        self._queue.put_nowait((request, future, op))
        return future

    async def weight_sum(self, ws):
        """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (torch.Tensor): The calculated mass sums for the given distribution.
        """
        return await self._queue_request(ws, TrieOp.SUM)

    async def weight_max(self, ws):
        """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (torch.Tensor): The calculated max weights for the given distribution.
        """
        return await self._queue_request(ws, TrieOp.MAX)

    def start(self):
        """Start the background processing task if not already running."""
        if not self._task or self._task.done():
            logger.debug("starting background loop")
            # Create a new queue so that it is bound to the current event loop
            self._queue = asyncio.Queue()
            self._task = asyncio.create_task(self._background_loop())

    async def _background_loop(self):
        """Background task that processes queued weight sum and max requests.

        Continuously monitors the queue for new requests and processes them in batches
        using the underlying trie implementation.

        Raises:
            (Exception): If any error occurs during processing, it is propagated to all
                         pending futures in the current batch.
        """
        while True:
            try:
                op_groups = defaultdict(list)

                request, future, op = await self._queue.get()
                op_groups[op].append((request, future))

                try:
                    while True:
                        request, future, op = self._queue.get_nowait()
                        op_groups[op].append((request, future))
                except asyncio.QueueEmpty:
                    pass

                for op, group in op_groups.items():
                    requests, futures = zip(*group)

                    if op == TrieOp.SUM:
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(f"processing {len(requests)} sum requests")
                        results = self.trie.batch_weight_sum(requests)
                    elif op == TrieOp.MAX:
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(f"processing {len(requests)} max requests")
                        results = self.trie.batch_weight_max(requests)
                    else:
                        raise ValueError(f"Unknown trie operation: {op}")

                    for future, result in zip(futures, results):
                        future.set_result(result)

            except Exception as e:
                for group in op_groups.values():
                    for _, future in group:
                        if not future.done():
                            future.set_exception(e)
                raise

    async def cleanup(self):
        """Async cleanup - preferred method"""
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    def shutdown(self):
        """Stop the background processing task and cleanup resources."""
        if self._task is not None:
            try:
                self._task.cancel()
            except RuntimeError:
                # Ignore runtime errors that might occur if event loop is closed
                pass
            self._task = None

    def __del__(self):
        self.shutdown()
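A hedged sketch of the batching behavior: weight_sum calls issued concurrently are collected from the queue and served by a single batch_weight_sum call in the background loop (module path assumed as above):

import asyncio
import torch
from genlm.bytes.trie import AsyncTokenByteTrie  # assumed module path

async def main():
    atrie = AsyncTokenByteTrie.from_vocab([b"a", b"ab", b"b"])
    w1 = torch.tensor([0.5, 0.25, 0.25])
    w2 = torch.tensor([0.2, 0.3, 0.5])
    # Both requests are queued and processed together by the background loop.
    m1, m2 = await asyncio.gather(atrie.weight_sum(w1), atrie.weight_sum(w2))
    await atrie.cleanup()
    return m1, m2

asyncio.run(main())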

__init__(trie)

Initialize an AsyncTokenByteTrie.

Parameters:

    trie (TokenByteTrie): The underlying TokenByteTrie instance. Required.
Source code in genlm/bytes/trie.py
def __init__(self, trie):
    """Initialize an `AsyncTokenByteTrie`.

    Args:
        trie (TokenByteTrie): The underlying `TokenByteTrie` instance
    """
    self.trie = trie
    self._queue = None
    self._task = None

from_vocab(vocab, **kwargs) classmethod

Creates an AsyncTokenByteTrie from a vocabulary.

Parameters:

    vocab (list): The vocabulary over which the trie will be defined. Required.
    **kwargs (dict): Additional arguments passed to the trie constructor. Default: {}.

Returns:

    (AsyncTokenByteTrie): The initialized asynchronous trie instance.

Source code in genlm/bytes/trie.py
@classmethod
def from_vocab(cls, vocab, **kwargs):
    """Creates an `AsyncTokenByteTrie` from a vocabulary.

    Args:
        vocab (list): The vocabulary over which the trie will be defined.
        **kwargs (dict): Additional arguments passed to the trie constructor

    Returns:
        (AsyncTokenByteTrie): The initialized asynchronous trie instance.
    """
    trie = TokenByteTrie(decode=vocab, **kwargs)
    return cls(trie)

weight_sum(ws) async

Queue a weight_sum request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name  Type    Description                                     Default
ws    Tensor  Token weights, shape (len(self.trie.decode),).  required

Returns:

Type     Description
ndarray  The calculated mass sums for the given distribution.

Source code in genlm/bytes/trie.py
async def weight_sum(self, ws):
    """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated mass sums for the given distribution.
    """
    return await self._queue_request(ws, TrieOp.SUM)

weight_max(ws) async

Queue a weight_max request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name  Type    Description                                     Default
ws    Tensor  Token weights, shape (len(self.trie.decode),).  required

Returns:

Type     Description
ndarray  The calculated max weights for the given distribution.

Source code in genlm/bytes/trie.py
async def weight_max(self, ws):
    """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated max weights for the given distribution.
    """
    return await self._queue_request(ws, TrieOp.MAX)
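
Taken together, from_vocab, weight_sum, and weight_max can be exercised as below. This is a minimal usage sketch: the three-token byte vocabulary and the weight values are hypothetical, the import path is assumed from the source location shown above, and start() is called explicitly for clarity:

import asyncio
import torch
from genlm.bytes.trie import AsyncTokenByteTrie

async def main():
    vocab = [b"a", b"ab", b"b"]  # hypothetical byte-string vocabulary
    trie = AsyncTokenByteTrie.from_vocab(vocab)
    trie.start()  # bind the queue and background loop to this event loop

    ws = torch.tensor([0.5, 0.25, 0.25])  # one weight per vocabulary token
    # Concurrent requests are drained from the queue and served in one batch.
    sums, maxes = await asyncio.gather(trie.weight_sum(ws), trie.weight_max(ws))
    print(sums, maxes)

    await trie.cleanup()

asyncio.run(main())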

start()

Start the background processing task if not already running.

Source code in genlm/bytes/trie.py
def start(self):
    """Start the background processing task if not already running."""
    if not self._task or self._task.done():
        logger.debug("starting background loop")
        # Create a new queue so that it is bound to the current event loop
        self._queue = asyncio.Queue()
        self._task = asyncio.create_task(self._background_loop())

cleanup() async

Asynchronously cancel and await the background task. Preferred over shutdown when an event loop is running.

Source code in genlm/bytes/trie.py
async def cleanup(self):
    """Asynchronously cancel and await the background task. Preferred over
    `shutdown` when an event loop is running."""
    if self._task and not self._task.done():
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        self._task = None

shutdown()

Stop the background processing task and clean up resources.

Source code in genlm/bytes/trie.py
def shutdown(self):
    """Stop the background processing task and clean up resources."""
    if self._task is not None:
        try:
            self._task.cancel()
        except RuntimeError:
            # Ignore runtime errors that can occur if the event loop is closed.
            pass
        self._task = None

Chart

Bases: dict

A specialized dictionary for managing probability distributions.

Extends dict with operations useful for probability distributions and numeric computations, including arithmetic operations, normalization, and visualization.

Parameters:

Name  Type   Description                     Default
zero  Any    Default value for missing keys  required
vals  tuple  Initial (key, value) pairs      ()
Source code in genlm/bytes/util.py
class Chart(dict):
    """A specialized dictionary for managing probability distributions.

    Extends dict with operations useful for probability distributions and numeric computations,
    including arithmetic operations, normalization, and visualization.

    Args:
        zero (Any): Default value for missing keys
        vals (tuple, optional): Initial (key, value) pairs
    """

    def __init__(self, zero, vals=()):
        self.zero = zero
        super().__init__(vals)

    def __missing__(self, k):
        return self.zero

    def spawn(self):
        return Chart(self.zero)

    def __add__(self, other):
        new = self.spawn()
        for k, v in self.items():
            new[k] += v
        for k, v in other.items():
            new[k] += v
        return new

    def __mul__(self, other):
        new = self.spawn()
        for k in self:
            v = self[k] * other[k]
            if v == self.zero:
                continue
            new[k] += v
        return new

    def copy(self):
        return Chart(self.zero, self)

    def trim(self):
        return Chart(self.zero, {k: v for k, v in self.items() if v != self.zero})

    def metric(self, other):
        assert isinstance(other, Chart)
        err = 0
        for x in self.keys() | other.keys():
            err = max(err, abs(self[x] - other[x]))
        return err

    def _repr_html_(self):
        return (
            '<div style="font-family: Monospace;">'
            + format_table(self.trim().items(), headings=["key", "value"])
            + "</div>"
        )

    def __repr__(self):
        return repr({k: v for k, v in self.items() if v != self.zero})

    def __str__(self, style_value=lambda k, v: str(v)):
        def key(k):
            return -self[k]

        return (
            "Chart {\n"
            + "\n".join(
                f"  {k!r}: {style_value(k, self[k])},"
                for k in sorted(self, key=key)
                if self[k] != self.zero
            )
            + "\n}"
        )

    def assert_equal(self, want, *, domain=None, tol=1e-5, verbose=False, throw=True):
        if not isinstance(want, Chart):
            want = Chart(self.zero, want)
        if domain is None:
            domain = self.keys() | want.keys()
        assert verbose or throw
        errors = []
        for x in domain:
            if abs(self[x] - want[x]) <= tol:
                if verbose:
                    print(colors.mark(True), x, self[x])
            else:
                if verbose:
                    print(colors.mark(False), x, self[x], want[x])
                errors.append(x)
        if throw:
            for x in errors:
                raise AssertionError(f"{x}: {self[x]} {want[x]}")

    def argmax(self):
        return max(self, key=self.__getitem__)

    def argmin(self):
        return min(self, key=self.__getitem__)

    def top(self, k):
        """Return a new Chart with the k largest-valued entries."""
        keys = sorted(self, key=self.__getitem__, reverse=True)[:k]
        return Chart(self.zero, {key: self[key] for key in keys})

    def max(self):
        return max(self.values())

    def min(self):
        return min(self.values())

    def sum(self):
        return sum(self.values())

    def sort(self, **kwargs):
        return Chart(self.zero, [(k, self[k]) for k in sorted(self, **kwargs)])

    def sort_descending(self):
        return Chart(
            self.zero, [(k, self[k]) for k in sorted(self, key=lambda k: -self[k])]
        )

    def normalize(self):
        Z = self.sum()
        if Z == 0:
            return self
        return Chart(self.zero, [(k, v / Z) for k, v in self.items()])

    def filter(self, f):
        return Chart(self.zero, [(k, v) for k, v in self.items() if f(k)])

    def map_values(self, f):
        return Chart(f(self.zero), [(k, f(v)) for k, v in self.items()])

    def map_keys(self, f):
        return Chart(self.zero, [(f(k), v) for k, v in self.items()])

    def project(self, f):
        "Apply the function `f` to each key; summing when f-transformed keys overlap."
        out = self.spawn()
        for k, v in self.items():
            out[f(k)] += v
        return out

    # TODO: the more general version of this method is join
    def compare(self, other, *, domain=None):
        if not isinstance(other, Chart):
            other = Chart(self.zero, other)
        if domain is None:
            domain = self.keys() | other.keys()
        rows = []
        for x in domain:
            m = abs(self[x] - other[x])
            rows.append(dict(key=x, self=self[x], other=other[x], metric=m))
        return pd.DataFrame(rows)

    def to_dict(self):
        return {k: v for k, v in self.items()}
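
As a quick illustration of the Chart API (a toy distribution with arbitrary values; the import path is assumed from the source location shown above):

from genlm.bytes.util import Chart

c = Chart(0.0, {"a": 3.0, "b": 1.0})
print(c["z"])            # 0.0 -- missing keys return the zero element
d = (c + c).normalize()  # pointwise addition, then division by the sum
print(d.argmax())        # 'a'
print(d.to_dict())       # {'a': 0.75, 'b': 0.25}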

project(f)

Apply the function f to each key; summing when f-transformed keys overlap.

Source code in genlm/bytes/util.py
def project(self, f):
    "Apply the function `f` to each key; summing when f-transformed keys overlap."
    out = self.spawn()
    for k, v in self.items():
        out[f(k)] += v
    return out
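
For instance, projecting keys through a function that maps several keys to the same image sums their values (a toy example; import path assumed as above):

from genlm.bytes.util import Chart

c = Chart(0, {"apple": 2, "avocado": 3, "banana": 1})
by_initial = c.project(lambda k: k[0])  # collapse keys to their first letter
print(by_initial.to_dict())             # {'a': 5, 'b': 1}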