lm_state

genlm.bytes.byte_lm.lm_state

StatefulTokenizedLM

A stateful tokenized language model that maintains a token-ID context and computes next-token log probabilities.

Parameters:

Name                Type     Description                                         Default
model               AsyncLM  The underlying language model                       required
context             list     List of token IDs representing the current context  required
n_calls             int      Number of times the model has been called           0
max_context_length  int      Maximum length of context to maintain               None
Source code in genlm/bytes/byte_lm/lm_state.py
class StatefulTokenizedLM:
    """A stateful tokenized language model that maintains context and generates next token logprobs.

    Args:
        model (genlm.backend.AsyncLM): The underlying language model
        context (list): List of token IDs representing the current context
        n_calls (int): Number of times the model has been called
        max_context_length (int, optional): Maximum length of context to maintain
    """

    def __init__(self, model, context, n_calls=0, max_context_length=None):
        self.model = model
        self.context = context
        self._n_calls = n_calls
        self.max_context_length = max_context_length

    @classmethod
    def initial(cls, model, initial_context=None, max_context_length=None):
        """Creates an initial state for the language model.

        Args:
            model (genlm.backend.AsyncLM): The language model to use
            initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id]
            max_context_length (int, optional): Maximum context length to maintain

        Returns:
            (StatefulTokenizedLM): A new instance with initial state
        """
        if initial_context is None:
            initial_context = [model.tokenizer.bos_token_id]
        return cls(model, initial_context, max_context_length=max_context_length)

    def __lshift__(self, token):
        """Adds a new token to the context and returns a new state.

        Args:
            token (int): Token ID to add to context

        Returns:
            (StatefulTokenizedLM): New state with updated context
        """
        assert isinstance(token, int)
        context = self.context
        if (
            self.max_context_length is not None
            and len(context) >= self.max_context_length
        ):
            # Keep only the most recent tokens, leaving room for the new one.
            context = context[-(self.max_context_length - 1) :]
        return StatefulTokenizedLM(
            self.model,
            context + [token],
            n_calls=self._n_calls,
            # Propagate the limit so subsequent states keep truncating.
            max_context_length=self.max_context_length,
        )

    async def logp_next(self):
        """Computes log probabilities for the next token given the current context.

        Returns:
            (torch.Tensor): Log probabilities for next tokens
        """
        self._n_calls += 1
        return await self.model.next_token_logprobs(self.context)

    def __repr__(self):
        return colors.purple % (
            "|".join([escape(self.model.byte_vocab[x]) for x in self.context])
        )
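
A minimal usage sketch. The backend below is a stub standing in for a real genlm.backend.AsyncLM (a real backend wraps an actual transformer); only the two pieces of the interface the state touches, tokenizer.bos_token_id and next_token_logprobs, are assumed here.

import asyncio
import torch

from genlm.bytes.byte_lm.lm_state import StatefulTokenizedLM

class StubTokenizer:
    bos_token_id = 0

class StubAsyncLM:
    # Stand-in for genlm.backend.AsyncLM: uniform next-token logprobs.
    tokenizer = StubTokenizer()
    vocab_size = 8

    async def next_token_logprobs(self, context):
        # Uniform log-distribution over the vocabulary, ignoring the context.
        return torch.log_softmax(torch.zeros(self.vocab_size), dim=-1)

async def main():
    state = StatefulTokenizedLM.initial(StubAsyncLM(), max_context_length=4)
    state = state << 3 << 5            # each << returns a fresh state
    logps = await state.logp_next()    # torch.Tensor of shape (vocab_size,)
    print(state.context, logps.shape)

asyncio.run(main())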

initial(model, initial_context=None, max_context_length=None) classmethod

Creates an initial state for the language model.

Parameters:

Name                Type     Description                                                          Default
model               AsyncLM  The language model to use                                            required
initial_context     list     Initial context of token IDs. Defaults to [tokenizer.bos_token_id]  None
max_context_length  int      Maximum context length to maintain                                   None

Returns:

Type                 Description
StatefulTokenizedLM  A new instance with initial state

Source code in genlm/bytes/byte_lm/lm_state.py
@classmethod
def initial(cls, model, initial_context=None, max_context_length=None):
    """Creates an initial state for the language model.

    Args:
        model (genlm.backend.AsyncLM): The language model to use
        initial_context (list, optional): Initial context of token IDs. Defaults to [tokenizer.bos_token_id]
        max_context_length (int, optional): Maximum context length to maintain

    Returns:
        (StatefulTokenizedLM): A new instance with initial state
    """
    if initial_context is None:
        initial_context = [model.tokenizer.bos_token_id]
    return cls(model, initial_context, max_context_length=max_context_length)
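
For instance (token IDs are illustrative; model is any AsyncLM-compatible backend):

# Default: the context starts at the tokenizer's BOS token.
state = StatefulTokenizedLM.initial(model)

# Or resume from an explicit token-ID prefix, with truncation enabled.
state = StatefulTokenizedLM.initial(model, initial_context=[101, 2023], max_context_length=512)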

__lshift__(token)

Adds a new token to the context and returns a new state.

Parameters:

Name   Type  Description                 Default
token  int   Token ID to add to context  required

Returns:

Type                 Description
StatefulTokenizedLM  New state with updated context

Source code in genlm/bytes/byte_lm/lm_state.py
def __lshift__(self, token):
    """Adds a new token to the context and returns a new state.

    Args:
        token (int): Token ID to add to context

    Returns:
        (StatefulTokenizedLM): New state with updated context
    """
    assert isinstance(token, int)
    context = self.context
    if (
        self.max_context_length is not None
        and len(context) >= self.max_context_length
    ):
        # Keep only the most recent tokens, leaving room for the new one.
        context = context[-(self.max_context_length - 1) :]
    return StatefulTokenizedLM(
        self.model,
        context + [token],
        n_calls=self._n_calls,
        # Propagate the limit so subsequent states keep truncating.
        max_context_length=self.max_context_length,
    )
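
Since each shift returns a fresh state, updates chain naturally and earlier states remain usable (token IDs are illustrative):

s0 = StatefulTokenizedLM.initial(model)
s2 = s0 << 17 << 42   # two successive states; s0 is untouched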

logp_next() async

Computes log probabilities for the next token given the current context.

Returns:

Type    Description
Tensor  Log probabilities for next tokens

Source code in genlm/bytes/byte_lm/lm_state.py
async def logp_next(self):
    """Computes log probabilities for the next token given the current context.

    Returns:
        (torch.Tensor): Log probabilities for next tokens
    """
    self._n_calls += 1
    return await self.model.next_token_logprobs(self.context)
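
A sketch of inspecting the resulting distribution, inside an async context:

logps = await state.logp_next()   # torch.Tensor over the token vocabulary
best = int(logps.argmax())        # ID of the most likely next token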

StatefulByteLM

Bases: ABC

Abstract base class for byte-level language models with state.

Source code in genlm/bytes/byte_lm/lm_state.py
class StatefulByteLM(ABC):
    """Abstract base class for byte-level language models with state."""

    @abstractmethod
    async def __lshift__(self, b: int):
        """Adds a byte to the state and returns new state.

        Args:
            b (int): Byte to add

        Returns:
            (StatefulByteLM): New state with added byte
        """
        pass

    def prune(self):
        """Prunes the current state if needed.

        Override in subclasses.

        Returns:
            (StatefulByteLM): Pruned state
        """
        return self

    @abstractmethod
    async def logp_next(self):
        """Computes the log probability distribution for the next byte.

        Returns:
            (LazyByteProbs): Log probabilities for next possible bytes
        """
        pass

    async def prefill(self, bs):
        """Prefills the model state with a sequence of bytes.

        Args:
            bs (list[int]): Sequence of bytes to add to state

        Returns:
            (StatefulByteLM): New state with all bytes added
        """
        state = self
        for b in bs:
            state = await (state.prune() << b)
        return state

    async def greedy(self, context, steps):
        """Performs greedy decoding for given number of steps.

        Args:
            context (bytes): Initial byte context
            steps (int): Number of generation steps

        Returns:
            (bytes): Generated byte sequence
        """
        context = list(context)
        state = await self.prefill(context)
        for _ in range(steps):
            Q = (await state.logp_next()).materialize()
            b = Q.argmax()
            state = await (state.prune() << b)
            context.append(b)
        return bytes(context)

    async def sample(self, context, steps, draw=sample_dict):
        """Samples from the model for given number of steps.

        Args:
            context (bytes): Initial byte context
            steps (int): Number of generation steps
            draw: Sampling function to use (default: sample_dict)

        Returns:
            (bytes): Generated byte sequence
        """
        context = list(context)
        state = await self.prefill(context)
        for _ in range(steps):
            Q = (await state.logp_next()).materialize()
            b = draw(Q.map_values(exp))
            state = await (state.prune() << b)
            context.append(b)
        return bytes(context)

    async def cleanup(self):
        """Performs any necessary cleanup of the model state."""
        pass
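
A toy concrete subclass, to illustrate the contract. ByteChart and LazyUniform below are self-contained stand-ins for the library's materialized chart and LazyByteProbs types; only the methods that greedy() and sample() call, materialize(), argmax(), and map_values(), are assumed.

import asyncio
from math import log

from genlm.bytes.byte_lm.lm_state import StatefulByteLM

class ByteChart(dict):
    # Minimal stand-in for the materialized chart greedy()/sample() expect.
    def argmax(self):
        return max(self, key=self.__getitem__)

    def map_values(self, f):
        return ByteChart({k: f(v) for k, v in self.items()})

class LazyUniform:
    # Stand-in for LazyByteProbs: builds the 256-entry table on demand.
    def materialize(self):
        return ByteChart({b: log(1 / 256) for b in range(256)})

class UniformByteLM(StatefulByteLM):
    # Every byte is equally likely; the state is just the byte history.
    def __init__(self, history=()):
        self.history = tuple(history)

    async def __lshift__(self, b: int):
        return UniformByteLM(self.history + (b,))  # fresh state each time

    async def logp_next(self):
        return LazyUniform()

print(asyncio.run(UniformByteLM().greedy(b"hi ", steps=2)))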

__lshift__(b) abstractmethod async

Adds a byte to the state and returns new state.

Parameters:

Name  Type  Description  Default
b     int   Byte to add  required

Returns:

Type            Description
StatefulByteLM  New state with added byte

Source code in genlm/bytes/byte_lm/lm_state.py
@abstractmethod
async def __lshift__(self, b: int):
    """Adds a byte to the state and returns new state.

    Args:
        b (int): Byte to add

    Returns:
        (StatefulByteLM): New state with added byte
    """
    pass

prune()

Prunes the current state if needed.

Override in subclasses.

Returns:

Type            Description
StatefulByteLM  Pruned state

Source code in genlm/bytes/byte_lm/lm_state.py
def prune(self):
    """Prunes the current state if needed.

    Override in subclasses.

    Returns:
        (StatefulByteLM): Pruned state
    """
    return self

logp_next() abstractmethod async

Computes the log probability distribution for the next byte.

Returns:

Type           Description
LazyByteProbs  Log probabilities for next possible bytes

Source code in genlm/bytes/byte_lm/lm_state.py
@abstractmethod
async def logp_next(self):
    """Computes the log probability distribution for the next byte.

    Returns:
        (LazyByteProbs): Log probabilities for next possible bytes
    """
    pass

prefill(bs) async

Prefills the model state with a sequence of bytes.

Parameters:

Name  Type       Description                        Default
bs    list[int]  Sequence of bytes to add to state  required

Returns:

Type            Description
StatefulByteLM  New state with all bytes added

Source code in genlm/bytes/byte_lm/lm_state.py
async def prefill(self, bs):
    """Prefills the model state with a sequence of bytes.

    Args:
        bs (list[int]): Sequence of bytes to add to state

    Returns:
        (StatefulByteLM): New state with all bytes added
    """
    state = self
    for b in bs:
        state = await (state.prune() << b)
    return state
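
For example, seeding a state from raw bytes before generating (lm is any concrete StatefulByteLM):

state = await lm.prefill(list(b"Once upon a time"))   # bytes -> list[int]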

greedy(context, steps) async

Performs greedy decoding for a given number of steps.

Parameters:

Name     Type   Description                 Default
context  bytes  Initial byte context        required
steps    int    Number of generation steps  required

Returns:

Type   Description
bytes  Generated byte sequence

Source code in genlm/bytes/byte_lm/lm_state.py
async def greedy(self, context, steps):
    """Performs greedy decoding for given number of steps.

    Args:
        context (bytes): Initial byte context
        steps (int): Number of generation steps

    Returns:
        (bytes): Generated byte sequence
    """
    context = list(context)
    state = await self.prefill(context)
    for _ in range(steps):
        Q = (await state.logp_next()).materialize()
        b = Q.argmax()
        state = await (state.prune() << b)
        context.append(b)
    return bytes(context)
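
Hypothetical usage, inside an async context:

out = await lm.greedy(b"The capital of France is", steps=16)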

sample(context, steps, draw=sample_dict) async

Samples from the model for a given number of steps.

Parameters:

Name     Type   Description                 Default
context  bytes  Initial byte context        required
steps    int    Number of generation steps  required
draw            Sampling function to use    sample_dict

Returns:

Type   Description
bytes  Generated byte sequence

Source code in genlm/bytes/byte_lm/lm_state.py
async def sample(self, context, steps, draw=sample_dict):
    """Samples from the model for given number of steps.

    Args:
        context (bytes): Initial byte context
        steps (int): Number of generation steps
        draw: Sampling function to use (default: sample_dict)

    Returns:
        (bytes): Generated byte sequence
    """
    context = list(context)
    state = await self.prefill(context)
    for _ in range(steps):
        Q = (await state.logp_next()).materialize()
        b = draw(Q.map_values(exp))
        state = await (state.prune() << b)
        context.append(b)
    return bytes(context)
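
The draw argument makes the sampling rule pluggable. Below is a sketch of a top-k draw, assuming the materialized distribution passed to draw behaves like a dict mapping bytes to probabilities (as the default sample_dict path suggests):

import random

def top_k_draw(probs, k=16):
    # Keep the k highest-probability bytes; random.choices renormalizes.
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    return random.choices([b for b, _ in top], weights=[p for _, p in top])[0]

out = await lm.sample(b"Hello, ", steps=32, draw=top_k_draw)   # inside async code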

cleanup() async

Performs any necessary cleanup of the model state.

Source code in genlm/bytes/byte_lm/lm_state.py
async def cleanup(self):
    """Performs any necessary cleanup of the model state."""
    pass