Skip to content

control

Potential

Bases: ABC, PotentialOps, PotentialTests

Abstract base class for potentials.

A Potential is a function that maps sequences of tokens in a vocabulary to non-negative real numbers (weights).

Potentials assign weights to sequences of tokens based on whether they are complete sequences or prefixes of complete sequences.

  • complete: Assess the log weight of a sequence of tokens in the vocabulary as a complete sequence.
  • prefix: Assess the log weight of a sequence of tokens in the vocabulary as a prefix.

Potentials additionally implement a logw_next method:

  • logw_next: Compute the next-token log weights of each token in the vocabulary and a special EOS (end-of-sequence) token given a context.

Subclasses must minimally implement complete and prefix. logw_next and batched versions of the above methods come with default implementations, but may be overridden by subclasses for improved performance.

All Potentials must satisfy a set of properties which can be tested using PotentialTests.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `token_type` | `TokenType` | The type of tokens in the vocabulary. |
| `vocab` | `list` | List of tokens making up the vocabulary. |
| `eos` | `EndOfSequence` | Special token to use as end-of-sequence. |
| `vocab_eos` | `list` | List of tokens in `vocab` and `eos`. `eos` is assumed to be the last token in `vocab_eos`. |
| `lookup` | `dict` | Mapping from tokens and `eos` to their indices in `vocab_eos`. |

Source code in genlm/control/potential/base.py
class Potential(ABC, PotentialOps, PotentialTests):
    """Abstract base class for potentials.

    A Potential is a function that maps sequences of tokens in a vocabulary to non-negative real numbers (weights).

    Potentials assign weights to sequences of tokens based on whether they are complete sequences or prefixes of complete sequences.

    - `complete`: Assess the log weight of a sequence of tokens in the vocabulary as a complete sequence.
    - `prefix`: Assess the log weight of a sequence of tokens in the vocabulary as a prefix.

    Potentials additionally implement a `logw_next` method:

    - `logw_next`: Compute the next-token log weights of each token in the vocabulary and a special EOS (end-of-sequence) token given a context.

    Subclasses must minimally implement `complete` and `prefix`. `logw_next` and batched versions of the above methods
    come with default implementations, but may be overridden by subclasses for improved performance.

    All Potentials must satisfy a set of properties which can be tested using [PotentialTests][genlm.control.potential.testing.PotentialTests].

    Attributes:
        token_type (TokenType): The type of tokens in the vocabulary.
        vocab (list): List of tokens making up the vocabulary.
        eos (EndOfSequence): Special token to use as end-of-sequence.
        vocab_eos (list): List of tokens in `vocab` and `eos`. `eos` is assumed to be the last token in `vocab_eos`.
        lookup (dict): Mapping from tokens and `eos` to their indices in `vocab_eos`.
    """

    def __init__(self, vocabulary, token_type=None, eos=None):
        """
        Initialize the potential.

        Args:
            vocabulary (list): List of tokens that make up the vocabulary.
            token_type (TokenType, optional): Optional TokenType of all elements of the vocabulary.
                If None, will be inferred from vocabulary.
            eos (EndOfSequence, optional): Special token to use as end-of-sequence. Defaults to `EOS`.
                In general, this should not be set by users.

        Raises:
            ValueError: If vocabulary is empty, if `token_type` is given but is not a `TokenType`,
                if `eos` is given but is not an `EndOfSequence`, or if vocabulary contains duplicate tokens.
            TypeError: If vocabulary contains tokens which are not of `token_type`.
        """
        if not vocabulary:
            raise ValueError("vocabulary cannot be empty")

        if token_type is None:
            token_type = infer_vocabulary_type(vocabulary)
        elif not isinstance(token_type, TokenType):
            raise ValueError(f"token_type must be a TokenType, got {token_type!r}.")

        if not all(token_type.check(x) for x in vocabulary):
            raise TypeError(f"Tokens in vocabulary must be of type {token_type}.")

        # Compare against None explicitly: a truthiness test would silently swap a
        # falsy-but-invalid argument (e.g. 0 or "") for the default instead of rejecting it.
        if eos is not None and not isinstance(eos, EndOfSequence):
            raise ValueError(f"EOS must be an instance of EndOfSequence, got {eos!r}.")

        self.eos = EOS if eos is None else eos

        self.token_type = token_type
        self.vocab = vocabulary
        # EOS is appended last, so its index in `vocab_eos` is `len(self.vocab)`.
        self.vocab_eos = self.vocab + [self.eos]
        self.lookup = {}
        for i, x in enumerate(vocabulary):
            if x in self.lookup:
                raise ValueError(f"Duplicate token {x!r} found in vocabulary")
            self.lookup[x] = i
        self.lookup[self.eos] = len(self.vocab)

    ####################
    # Instance methods #
    ####################

    @abstractmethod
    async def complete(self, context):
        """Assess the weight of `context` as a complete sequence.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (float): Log weight of the context under the language.
        """
        pass  # pragma: no cover

    @abstractmethod
    async def prefix(self, context):
        """Assess the weight of `context` as a prefix.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (float): Log weight of the context as a prefix.
        """
        pass  # pragma: no cover

    async def score(self, context):
        """Assess the weight of `context` based on EOS-termination.

        This is a convenience method which dispatches to `complete` if `context` ends with `self.eos`, otherwise to `prefix`.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (float): Log weight of the context, either as a prefix or complete sequence.
        """
        if context and context[-1] == self.eos:
            # The trailing EOS marker is stripped before scoring as a complete sequence.
            return await self.complete(context[:-1])
        else:
            return await self.prefix(context)

    async def logw_next(self, context):
        """Compute the next-token weights of each token in `self.vocab_eos` given `context`.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (LazyWeights): Weights of each token in the vocabulary and EOS.

        Raises:
            ValueError: If `context` has zero weight (log weight of -inf) under `prefix`.
        """
        ctx_log_w = await self.prefix(context)

        # The next-token weights are normalized by the prefix weight below, which
        # is undefined when the prefix itself has zero weight.
        if ctx_log_w == float("-inf"):
            raise ValueError(f"Context {context!r} has weight zero under `prefix`.")

        # Score every one-token extension (including EOS, which `batch_score` routes to `complete`).
        scores = await self.batch_score([[*context, x] for x in self.vocab_eos])
        # log(w(context + [x]) / w_prefix(context)) for each x.
        logws = scores - ctx_log_w

        return self.make_lazy_weights(logws)

    ###################
    # Batched methods #
    ###################

    async def batch_complete(self, contexts):
        """Batched equivalent to `complete`.

        Assess the weight of each context as a complete sequence.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (np.array): Array of log weights for each context.

        Raises:
            ValueError: If `contexts` is empty.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        return np.array(
            await asyncio.gather(*[self.complete(context) for context in contexts])
        )

    async def batch_prefix(self, contexts):
        """Batched equivalent to `prefix`.

        Assess the weight of each context as a prefix.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (np.array): Array of log weights for each context.

        Raises:
            ValueError: If `contexts` is empty.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        return np.array(
            await asyncio.gather(*[self.prefix(context) for context in contexts])
        )

    async def batch_score(self, contexts):
        """Batched equivalent to `score`.

        Assess the weight of each context based on EOS-termination.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (np.array): Array of log weights for each context.

        Raises:
            ValueError: If `contexts` is empty.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        # Partition contexts into EOS-terminated ones (scored by `batch_complete`
        # with EOS stripped) and the rest (scored by `batch_prefix`), remembering
        # original positions so results can be scattered back in input order.
        complete, prefix = [], []
        complete_indices, prefix_indices = [], []

        for i, context in enumerate(contexts):
            # We want == here instead of `is`.
            if context and context[-1] == self.eos:
                complete.append(context[:-1])
                complete_indices.append(i)
            else:
                prefix.append(context)
                prefix_indices.append(i)

        # The batched methods reject empty input, so only call them when needed.
        complete_scores = (
            await self.batch_complete(complete) if complete else np.array([])
        )
        prefix_scores = await self.batch_prefix(prefix) if prefix else np.array([])

        results = np.empty(len(contexts))
        if len(complete_scores) > 0:
            results[complete_indices] = complete_scores
        if len(prefix_scores) > 0:
            results[prefix_indices] = prefix_scores

        return results

    async def batch_logw_next(self, contexts):
        """Batched equivalent to `logw_next`.

        Computes the next-token weights of each token in `self.vocab_eos` given each context in the batch.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (list): List of LazyWeights objects, one for each context.

        Raises:
            ValueError: If `contexts` is empty, or if any context has zero weight (log weight of -inf) under `prefix`.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        return await asyncio.gather(*[self.logw_next(context) for context in contexts])

    #############
    # Utilities #
    #############

    def make_lazy_weights(self, weights, log=True):
        """Helper method to create a LazyWeights object over the potential's vocabulary and EOS.

        Args:
            weights (np.array): Array of weights.
            log (bool, optional): Whether the weights are in log space. Defaults to True.

        Returns:
            (LazyWeights): LazyWeights object defined over `self.vocab_eos`.
        """
        return LazyWeights(
            weights=weights, encode=self.lookup, decode=self.vocab_eos, log=log
        )

    def alloc_logws(self, default=float("-inf")):
        """Allocate a new array of log weights for the potential's vocabulary and EOS.

        Args:
            default (float, optional): Default log weight. Defaults to -inf.

        Returns:
            (np.array): Array of length `len(self.vocab_eos)` filled with `default`.
        """
        return np.full((len(self.vocab_eos),), default)

    def spawn(self):
        """
        Spawn a fresh instance of the potential.

        This method is not required by default, but may be implemented by subclasses
        to support CPU-parallelism using [`MultiProcPotential`][genlm.control.potential.multi_proc.MultiProcPotential].

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "Potential.spawn() must be implemented by subclasses."
        )

    async def cleanup(self):
        """
        Cleanup the potential.

        This method may be implemented by subclasses to release resources.
        The base implementation does nothing.
        """
        pass

__init__(vocabulary, token_type=None, eos=None)

Initialize the potential.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vocabulary` | `list` | List of tokens that make up the vocabulary. | required |
| `token_type` | `TokenType` | Optional TokenType of all elements of the vocabulary. If None, will be inferred from vocabulary. | `None` |
| `eos` | `EndOfSequence` | Special token to use as end-of-sequence. Defaults to `EOS`. In general, this should not be set by users. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If vocabulary is empty. |
| `TypeError` | If vocabulary contains tokens which are not of `token_type`. |

Source code in genlm/control/potential/base.py
def __init__(self, vocabulary, token_type=None, eos=None):
    """
    Initialize the potential.

    Args:
        vocabulary (list): List of tokens that make up the vocabulary.
        token_type (TokenType, optional): Optional TokenType of all elements of the vocabulary.
            If None, will be inferred from vocabulary.
        eos (EndOfSequence, optional): Special token to use as end-of-sequence. Defaults to `EOS`.
            In general, this should not be set by users.

    Raises:
        ValueError: If vocabulary is empty, if `token_type` is given but is not a `TokenType`,
            if `eos` is given but is not an `EndOfSequence`, or if vocabulary contains duplicate tokens.
        TypeError: If vocabulary contains tokens which are not of `token_type`.
    """
    if not vocabulary:
        raise ValueError("vocabulary cannot be empty")

    if token_type is None:
        token_type = infer_vocabulary_type(vocabulary)
    elif not isinstance(token_type, TokenType):
        raise ValueError(f"token_type must be a TokenType, got {token_type!r}.")

    if not all(token_type.check(x) for x in vocabulary):
        raise TypeError(f"Tokens in vocabulary must be of type {token_type}.")

    # Compare against None explicitly: a truthiness test would silently swap a
    # falsy-but-invalid argument (e.g. 0 or "") for the default instead of rejecting it.
    if eos is not None and not isinstance(eos, EndOfSequence):
        raise ValueError(f"EOS must be an instance of EndOfSequence, got {eos!r}.")

    self.eos = EOS if eos is None else eos

    self.token_type = token_type
    self.vocab = vocabulary
    # EOS is appended last, so its index in `vocab_eos` is `len(self.vocab)`.
    self.vocab_eos = self.vocab + [self.eos]
    self.lookup = {}
    for i, x in enumerate(vocabulary):
        if x in self.lookup:
            raise ValueError(f"Duplicate token {x!r} found in vocabulary")
        self.lookup[x] = i
    self.lookup[self.eos] = len(self.vocab)

complete(context) abstractmethod async

Assess the weight of context as a complete sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `context` | `list` | Sequence of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | Log weight of the context under the language. |

Source code in genlm/control/potential/base.py
@abstractmethod
async def complete(self, context):
    """Score `context` as a finished (complete) sequence.

    Concrete potentials must override this method.

    Args:
        context (list): Sequence of tokens.

    Returns:
        (float): Log weight of the context under the language.
    """
    ...  # pragma: no cover

prefix(context) abstractmethod async

Assess the weight of context as a prefix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `context` | `list` | Sequence of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | Log weight of the context as a prefix. |

Source code in genlm/control/potential/base.py
@abstractmethod
async def prefix(self, context):
    """Score `context` as a (possibly unfinished) prefix.

    Concrete potentials must override this method.

    Args:
        context (list): Sequence of tokens.

    Returns:
        (float): Log weight of the context as a prefix.
    """
    ...  # pragma: no cover

score(context) async

Assess the weight of context based on EOS-termination.

This is a convenience method which dispatches to complete if context ends with self.eos, otherwise to prefix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `context` | `list` | Sequence of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | Log weight of the context, either as a prefix or complete sequence. |

Source code in genlm/control/potential/base.py
async def score(self, context):
    """Score `context`, choosing `complete` or `prefix` by its last token.

    A context whose final token equals `self.eos` is treated as finished: the
    EOS marker is dropped and the remainder is passed to `complete`. Any other
    context (including an empty one) is passed unchanged to `prefix`.

    Args:
        context (list): Sequence of tokens.

    Returns:
        (float): Log weight of the context, either as a prefix or complete sequence.
    """
    terminated = bool(context) and context[-1] == self.eos
    if not terminated:
        return await self.prefix(context)
    return await self.complete(context[:-1])

logw_next(context) async

Compute the next-token weights of each token in self.vocab_eos given context.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `context` | `list` | Sequence of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `LazyWeights` | Weights of each token in the vocabulary and EOS. |

Source code in genlm/control/potential/base.py
async def logw_next(self, context):
    """Log weights of every possible next token (vocabulary plus EOS) after `context`.

    Each entry is the score of `context` extended by one token from
    `self.vocab_eos`, minus the prefix log weight of `context` itself.

    Args:
        context (list): Sequence of tokens.

    Returns:
        (LazyWeights): Weights of each token in the vocabulary and EOS.

    Raises:
        ValueError: If `context` has zero weight (log weight of -inf) under `prefix`.
    """
    base_logw = await self.prefix(context)
    if base_logw == float("-inf"):
        raise ValueError(f"Context {context!r} has weight zero under `prefix`.")

    extensions = [[*context, token] for token in self.vocab_eos]
    extension_logws = await self.batch_score(extensions)
    return self.make_lazy_weights(extension_logws - base_logw)

batch_complete(contexts) async

Batched equivalent to complete.

Assess the weight of each context as a complete sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `contexts` | `list` | List of sequences of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `np.array` | Array of log weights for each context. |

Source code in genlm/control/potential/base.py
async def batch_complete(self, contexts):
    """Score every context as a complete sequence (batched form of `complete`).

    One `complete` coroutine per context is created and all are awaited
    concurrently via `asyncio.gather`; results keep the input order.

    Args:
        contexts (list): List of sequences of tokens.

    Returns:
        (np.array): Array of log weights for each context.

    Raises:
        ValueError: If `contexts` is empty.
    """
    if not contexts:
        raise ValueError("Contexts must be non-empty.")

    pending = [self.complete(ctx) for ctx in contexts]
    log_weights = await asyncio.gather(*pending)
    return np.array(log_weights)

batch_prefix(contexts) async

Batched equivalent to prefix.

Assess the weight of each context as a prefix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `contexts` | `list` | List of sequences of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `np.array` | Array of log weights for each context. |

Source code in genlm/control/potential/base.py
async def batch_prefix(self, contexts):
    """Score every context as a prefix (batched form of `prefix`).

    All `prefix` calls are awaited concurrently with `asyncio.gather`, which
    returns results in the same order as `contexts`.

    Args:
        contexts (list): List of sequences of tokens.

    Returns:
        (np.array): Array of log weights for each context.

    Raises:
        ValueError: If `contexts` is empty.
    """
    if not contexts:
        raise ValueError("Contexts must be non-empty.")

    log_weights = await asyncio.gather(*map(self.prefix, contexts))
    return np.array(log_weights)

batch_score(contexts) async

Batched equivalent to score.

Assess the weight of each context based on EOS-termination.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `contexts` | `list` | List of sequences of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `np.array` | Array of log weights for each context. |

Source code in genlm/control/potential/base.py
async def batch_score(self, contexts):
    """Score a batch of contexts, each by EOS-termination (batched form of `score`).

    Contexts ending in `self.eos` are stripped of that marker and sent together
    to `batch_complete`; all others go together to `batch_prefix`. The two
    groups' scores are then written back at their original positions.

    Args:
        contexts (list): List of sequences of tokens.

    Returns:
        (np.array): Array of log weights for each context.

    Raises:
        ValueError: If `contexts` is empty.
    """
    if not contexts:
        raise ValueError("Contexts must be non-empty.")

    finished_pos, finished_ctxs = [], []
    open_pos, open_ctxs = [], []
    for pos, ctx in enumerate(contexts):
        # Equality (==), not identity, decides whether the last token is EOS.
        if not (ctx and ctx[-1] == self.eos):
            open_pos.append(pos)
            open_ctxs.append(ctx)
            continue
        finished_pos.append(pos)
        finished_ctxs.append(ctx[:-1])

    finished_scores = np.array([])
    if finished_ctxs:
        finished_scores = await self.batch_complete(finished_ctxs)
    open_scores = np.array([])
    if open_ctxs:
        open_scores = await self.batch_prefix(open_ctxs)

    out = np.empty(len(contexts))
    for positions, group_scores in ((finished_pos, finished_scores), (open_pos, open_scores)):
        if len(group_scores) > 0:
            out[positions] = group_scores
    return out

batch_logw_next(contexts) async

Batched equivalent to logw_next.

Computes the next-token weights of each token in self.vocab_eos given each context in the batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `contexts` | `list` | List of sequences of tokens. | required |

Returns:

| Type | Description |
| --- | --- |
| `list` | List of LazyWeights objects, one for each context. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If any context has zero weight (log weight of -inf) under `prefix`. |

Source code in genlm/control/potential/base.py
async def batch_logw_next(self, contexts):
    """Next-token weights for each context in a batch (batched form of `logw_next`).

    Every context's `logw_next` coroutine is awaited concurrently with
    `asyncio.gather`, and the results come back as a list in input order.

    Args:
        contexts (list): List of sequences of tokens.

    Returns:
        (list): List of LazyWeights objects, one for each context.

    Raises:
        ValueError: If `contexts` is empty, or if any context has zero weight
            (log weight of -inf) under `prefix`.
    """
    if not contexts:
        raise ValueError("Contexts must be non-empty.")

    per_context = map(self.logw_next, contexts)
    return await asyncio.gather(*per_context)

make_lazy_weights(weights, log=True)

Helper method to create a LazyWeights object over the potential's vocabulary and EOS.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weights` | `np.array` | Array of weights. | required |
| `log` | `bool` | Whether the weights are in log space. Defaults to True. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `LazyWeights` | LazyWeights object defined over `self.vocab_eos`. |

Source code in genlm/control/potential/base.py
def make_lazy_weights(self, weights, log=True):
    """Wrap `weights` in a LazyWeights indexed by this potential's tokens plus EOS.

    Token-to-index encoding uses `self.lookup`, and index-to-token decoding
    uses `self.vocab_eos`.

    Args:
        weights (np.array): Array of weights.
        log (bool, optional): Whether the weights are in log space. Defaults to True.

    Returns:
        (LazyWeights): LazyWeights object defined over `self.vocab_eos`.
    """
    lazy_kwargs = {
        "weights": weights,
        "encode": self.lookup,
        "decode": self.vocab_eos,
        "log": log,
    }
    return LazyWeights(**lazy_kwargs)

alloc_logws(default=float('-inf'))

Allocate a new array of log weights for the potential's vocabulary and EOS.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `default` | `float` | Default log weight. Defaults to -inf. | `float('-inf')` |

Returns:

| Type | Description |
| --- | --- |
| `np.array` | Array of length `len(self.vocab_eos)` filled with `default`. |

Source code in genlm/control/potential/base.py
def alloc_logws(self, default=float("-inf")):
    """Create a fresh 1-D log-weight array sized to the vocabulary plus EOS.

    Args:
        default (float, optional): Default log weight. Defaults to -inf.

    Returns:
        (np.array): Array of length `len(self.vocab_eos)` filled with `default`.
    """
    size = len(self.vocab_eos)
    return np.full(shape=(size,), fill_value=default)

spawn()

Spawn a fresh instance of the potential.

This method is not required, but may be implemented by subclasses to support CPU parallelism using `MultiProcPotential` (`genlm.control.potential.multi_proc.MultiProcPotential`).

Source code in genlm/control/potential/base.py
def spawn(self):
    """
    Spawn a fresh instance of the potential.

    This method is not required by default, but may be implemented by subclasses
    to support CPU-parallelism using [`MultiProcPotential`][genlm.control.potential.multi_proc.MultiProcPotential].

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(
        "Potential.spawn() must be implemented by subclasses."
    )

cleanup() async

Cleanup the potential.

This method may be implemented by subclasses to release resources.

Source code in genlm/control/potential/base.py
async def cleanup(self):
    """
    Release any resources held by the potential.

    The base implementation is a no-op; subclasses that acquire resources
    may override it to free them.
    """
    return None

PromptedLLM

Bases: Potential

A potential representing a language model conditioned on a fixed prompt prefix.

PromptedLLMs operate on byte sequences.

Notes on EOS Token Handling:

  • Tokens to treat as end-of-sequence tokens are specified via the eos_tokens argument.

  • These tokens are excluded from the potential's vocabulary and as such do not appear in the vocab attribute.

    This means they cannot appear in any input contexts to the potential nor in the output of logw_next. They can be used in the prompt however.

  • The log probability assigned to genlm.control's reserved EOS token is the sum of the log probabilities of all the specified EOS tokens.

This class wraps an AsyncLM instance.

Source code in genlm/control/potential/built_in/llm.py
class PromptedLLM(Potential):
    """A potential representing a language model conditioned on a fixed prompt prefix.

    `PromptedLLM`s operate on byte sequences.

    Notes on EOS Token Handling:\n
    - Tokens to treat as end-of-sequence tokens are specified via the `eos_tokens` argument.\n
    - These tokens are excluded from the potential's vocabulary and as such do not appear in the `vocab` attribute.\n
        This means they cannot appear in any input contexts to the potential nor in the output of `logw_next`. They can be used in the prompt however.\n
    - The log probability assigned to `genlm.control`'s reserved `EOS` token is the log of the summed probabilities
        of all the specified EOS tokens (i.e. the logsumexp of their log probabilities).\n

    This class wraps an `AsyncLM` instance.
    """

    def __init__(
        self,
        llm,
        prompt_ids=None,
        eos_tokens=None,
        temperature=1,
        token_maps=None,
    ):
        """
        Initializes the PromptedLLM potential.

        Args:
            llm (AsyncLM): The language model to use.
            prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
                Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via
                `set_prompt_from_str` or `prompt_ids`.
            eos_tokens (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
                Defaults to the EOS token of the language model's tokenizer.
            temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
            token_maps (TokenMappings, optional): A precomputed mapping of tokens to token IDs with the potential's vocabulary.
                If provided, `eos_tokens` must not be provided. Defaults to None, which constructs a TokenMappings from the language model's byte vocabulary and the EOS tokens.

        Raises:
            ValueError: If both `token_maps` and `eos_tokens` are provided.
        """
        self.model = llm
        # Copy into a fresh list so the caller's sequence is never aliased, and so that
        # non-list sequences (e.g. tuples) still support `self.prompt_ids + context_ids`.
        self.prompt_ids = list(prompt_ids) if prompt_ids is not None else []
        self.temperature = temperature

        if token_maps is not None:
            if eos_tokens is not None:
                raise ValueError(
                    "eos_tokens must not be provided when token_maps is provided."
                )
            self.token_maps = token_maps
        else:
            self.token_maps = TokenMappings.create(
                decode=self.model.byte_vocab,
                eos_tokens=eos_tokens
                or [self.model.byte_vocab[self.model.tokenizer.eos_token_id]],
            )

        super().__init__(vocabulary=self.token_maps.potential_vocab)

    @classmethod
    def from_name(
        cls,
        name,
        backend=None,
        eos_tokens=None,
        prompt_ids=None,
        temperature=1.0,
        **kwargs,
    ):
        """Create a `PromptedLLM` from a Hugging Face model name.

        Args:
            name (str): Name of the model to load
            backend (str, optional): `AsyncLM` backend to use:\n
                * 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage\n
                * 'hf' for an `AsyncTransformer`; ideal for CPU usage\n
                * 'mock' for a `MockAsyncLM`; ideal for testing.\n
                Defaults to 'vllm' if CUDA is available, otherwise 'hf'.
            eos_tokens (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
                Defaults to the EOS token of the language model's tokenizer.
            prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
                Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via `set_prompt_from_str` or `prompt_ids`.
            temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
            **kwargs (dict): Additional arguments passed to AsyncLM constructor

        Returns:
            (PromptedLLM): An instance of PromptedLLM
        """
        backend = backend or ("vllm" if torch.cuda.is_available() else "hf")
        model = load_model_by_name(name, backend=backend, **kwargs)
        return cls(
            model, prompt_ids=prompt_ids, eos_tokens=eos_tokens, temperature=temperature
        )

    @property
    def eos_tokens(self):
        """The tokens treated as end-of-sequence, as stored in `self.token_maps`."""
        return self.token_maps.eos_tokens

    @eos_tokens.setter
    def eos_tokens(self, value):
        # The vocabulary is derived from the EOS tokens at construction time, so they
        # cannot be swapped in place; a new instance must be spawned instead.
        raise ValueError(
            "Cannot reset eos_tokens after initialization. "
            "Use spawn_new_eos(new_eos_tokens) instead."
        )

    @property
    def prompt(self):
        """
        Get the current prompt as a list of byte sequences corresponding to the prompt token IDs.

        Returns:
            (list[bytes]|None): The current prompt as a list of bytes sequences or None if no prompt_ids are set.
        """
        if not self.prompt_ids:
            return  # pragma: no cover
        return [self.token_maps.decode[x] for x in self.prompt_ids]

    def set_prompt_from_str(self, prompt_str):
        """Set the fixed prompt from a string.

        Modifies `prompt_ids` to be the token IDs of the input prompt according to the language model's tokenizer.

        Args:
            prompt_str (str): The prompt to set.

        Raises:
            ValueError: If `prompt_str` is not a string.
        """
        # TODO: Handle race condition where prompt_ids reset concurrently.
        if not isinstance(prompt_str, str):
            raise ValueError(
                f"Prompt must a string got {type(prompt_str)}. "
                f"To set the prompt from a list of token IDs, use prompt_ids."
            )

        if prompt_str.endswith(" "):
            warnings.warn(
                "Prompt ends with whitespace, which may affect tokenization. "
                "Consider removing trailing whitespace.",
                stacklevel=2,
            )

        self.prompt_ids = self.model.tokenizer.encode(prompt_str)

    def encode_tokens(self, tokens):
        """Encode a list of byte tokens to a list of token IDs in
        the underlying language model's vocabulary.

        Args:
            tokens (list[bytes]): List of byte tokens to encode

        Returns:
            (list[int]): A list of token IDs corresponding to the input tokens.

        Raises:
            ValueError: If any token is not in the vocabulary
        """
        try:
            return [self.token_maps.encode[x] for x in tokens]
        except KeyError as e:
            raise ValueError(f"Token {e.args[0]} not in vocabulary") from e

    def decode_tokens(self, ids):
        """
        Decode a list of token IDs in the language model's vocabulary to a list of byte tokens.

        Args:
            ids (list[int]): A list of token IDs in the language model's vocabulary.

        Returns:
            (list[bytes]): A list of byte tokens corresponding to the input token IDs.
        """
        return [self.token_maps.decode[x] for x in ids]

    def tokenize(self, context_str):
        """Tokenize a string to a list of `bytes` objects, each corresponding to a token in the vocabulary.

        Uses the language model's tokenizer to map `context_str` to a list of token IDs, and then decodes the token IDs to bytes.

        Args:
            context_str (str): A string to encode

        Returns:
            (List[bytes]): A list of byte tokens corresponding to the input string.
        """
        return self.decode_tokens(self.model.tokenizer.encode(context_str))

    async def log_probability(self, context):
        """
        Compute the log probability of `context` given the prompt.

        Args:
            context (list[bytes]): A sequence of bytes tokens.

        Returns:
            (float): The log probability of `context`.
        """
        if not context:
            return 0

        context_ids = self.encode_tokens(context)
        return await self._log_probability(context_ids)

    async def _log_probability(self, context_ids):
        """Sum the (tempered) log probabilities of `context_ids` given the prompt.

        Args:
            context_ids (list[int]): Token IDs in the language model's vocabulary.

        Returns:
            (float): The log probability of `context_ids`; 0.0 for an empty sequence.
        """
        if not context_ids:
            # The empty sequence has log probability 0 (as in `log_probability`).
            # Returning early also avoids sending an empty batch of prefixes to the
            # model, which previously happened via `complete([])`.
            return 0.0

        # prefixes[i] = prompt + context_ids[:i], whose next token is context_ids[i].
        prefixes = [self.prompt_ids + context_ids[:i] for i in range(len(context_ids))]
        log_ps = self._maybe_temper(
            await self.model.batch_next_token_logprobs(prefixes)
        )
        target_ids = torch.tensor(context_ids, device=log_ps.device)
        with torch.no_grad():
            # Row i holds the next-token distribution for prefixes[i]; pick context_ids[i].
            token_logprobs = torch.gather(log_ps, 1, target_ids.unsqueeze(1))
            total_logprob = token_logprobs.sum().item()

        return total_logprob

    def _maybe_temper(self, logps):
        """Rescale log probabilities by `self.temperature` and renormalize over the last dim.

        Returns `logps` unchanged when the temperature is 1.
        """
        if self.temperature == 1:
            return logps
        return torch.log_softmax(logps / self.temperature, dim=-1)

    async def prefix(self, context):
        """
        Compute the log probability of `context` given the prompt.

        Args:
            context (list[bytes]): A sequence of bytes tokens.

        Returns:
            (float): The log probability of `context`.
        """
        return await self.log_probability(context)

    async def complete(self, context):
        """
        Compute the log probability of `context` followed by an end-of-sequence token, given the prompt.

        If the model has multiple eos tokens, their probabilities will be summed.

        Args:
            context (list[bytes]): A sequence of bytes tokens.

        Returns:
            (float): The log probability of the context followed by EOS.
        """
        context_ids = self.encode_tokens(context)
        logp_context = await self._log_probability(context_ids)
        logp_next = self._maybe_temper(
            await self.model.next_token_logprobs(self.prompt_ids + context_ids)
        )
        # NOTE(review): unlike `_process_logw_next`, these log probabilities are not
        # truncated to `len(self.token_maps.decode)` and renormalized — confirm intended.
        logp_eos = torch.logsumexp(logp_next[self.token_maps.eos_idxs], dim=0).item()
        return logp_context + logp_eos

    def _process_logw_next(self, logw_next):
        """Process the log probabilities for the next tokens.

        Truncates the raw log probabilities to the byte vocabulary (`len(self.token_maps.decode)`),
        renormalizes them with a log-softmax, writes the non-EOS entries into the first
        `len(self.vocab)` slots, and stores the combined EOS log probability (the logsumexp
        over `self.token_maps.eos_idxs`) in the final slot.

        Args:
            logw_next (torch.tensor): The log probabilities for the next tokens.

        Returns:
            (LazyWeights): Processed log probabilities for the next tokens.
        """
        # This is ugly, but it's useful for all potentials to adhere to the convention
        # of keeping the EOS token at the end of the weights array.

        # Cache eos_idxs_tensor and non_eos_indices on first use (and whenever the
        # device of the incoming tensor changes).
        if (
            not hasattr(self, "_eos_idxs_tensor")
            or not hasattr(self, "_non_eos_indices")
            or self._eos_idxs_tensor.device != logw_next.device
        ):
            self._eos_idxs_tensor = torch.tensor(
                self.token_maps.eos_idxs, device=logw_next.device
            )
            all_indices = torch.arange(
                len(self.token_maps.decode), device=logw_next.device
            )
            self._non_eos_indices = all_indices[
                ~torch.isin(all_indices, self._eos_idxs_tensor)
            ]

        logw_next = logw_next[: len(self.token_maps.decode)]
        logw_next = logw_next.log_softmax(dim=0)
        _logw_next = torch.full(
            (len(self.vocab) + 1,),
            float("-inf"),
            dtype=logw_next.dtype,
            device=logw_next.device,
        )
        _logw_next[: len(self.vocab)] = logw_next[self._non_eos_indices]

        # Special case: if only one EOS idx, just assign directly (avoids cost of logsumexp)
        if self._eos_idxs_tensor.numel() == 1:
            _logw_next[-1] = logw_next[self._eos_idxs_tensor]
        else:
            _logw_next[-1] = torch.logsumexp(logw_next[self._eos_idxs_tensor], dim=0)

        return self.make_lazy_weights(_logw_next.float().cpu().numpy())

    async def logw_next(self, context):
        """Get log probabilities for next tokens given the prompt and `context`.

        Args:
            context (List[bytes]): A sequence of bytes tokens.

        Returns:
            (LazyWeights): Log probabilities for next tokens and EOS.
        """
        logw_next = self._maybe_temper(
            await self.model.next_token_logprobs(
                self.prompt_ids + self.encode_tokens(context)
            )
        )
        return self._process_logw_next(logw_next)

    async def batch_logw_next(self, contexts):
        """Get log probabilities for next tokens given the prompt and `context`, for a batch of contexts.

        Args:
            contexts (list[list[bytes]]): A list of sequences of bytes tokens.

        Returns:
            (List[LazyWeights]): Log probabilities for next tokens and EOS for each context.
        """
        logw_nexts = self._maybe_temper(
            await self.model.batch_next_token_logprobs(
                [self.prompt_ids + self.encode_tokens(context) for context in contexts]
            )
        )
        return [self._process_logw_next(logw_next) for logw_next in logw_nexts]

    def __repr__(self):
        return f"PromptedLLM(prompt={self.prompt!r})"

    def spawn(self, prompt_ids=None, eos_tokens=None, temperature=None):
        """
        Spawn a new PromptedLLM.

        Args:
            prompt_ids (optional, list[int]): The prompt to use as a prompt prefix for all input contexts.
                Defaults to the same prompt_ids as `self`.
            eos_tokens (optional, list[bytes]): A list of tokens to treat as end-of-sequence tokens.
                Defaults to the same eos_tokens as `self`.
            temperature (optional, float): The temperature with which to rescale logprobs.
                Defaults to the same temperature as `self`.

        Returns:
            (PromptedLLM): A new PromptedLLM sharing `self.model`. Its prompt, eos tokens, and
                temperature default to those of `self` unless overridden.

        Note:
            This is a shallow copy. The new PromptedLLM will share the underlying AsyncLM instance.
        """
        prompt_ids = prompt_ids if prompt_ids is not None else self.prompt_ids.copy()
        temperature = temperature if temperature is not None else self.temperature

        if (eos_tokens is None) or (eos_tokens == self.token_maps.eos_tokens):
            # If the eos tokens don't change, we don't need to recompute the token maps or vocabulary.
            return PromptedLLM(
                self.model,
                prompt_ids=prompt_ids,
                temperature=temperature,
                token_maps=self.token_maps,
            )

        return PromptedLLM(
            self.model,
            prompt_ids=prompt_ids,
            eos_tokens=eos_tokens,
            temperature=temperature,
        )

    def spawn_new_eos(self, eos_tokens):
        """
        Create a new PromptedLLM with a different set of end-of-sequence tokens.

        Args:
            eos_tokens (list[bytes]): A list of tokens to treat as end-of-sequence tokens.

        Returns:
            (PromptedLLM): A new PromptedLLM with the specified end-of-sequence tokens.
                The new model will have the same prompt_ids as `self`.
        """
        return self.spawn(eos_tokens=eos_tokens)

    def to_autobatched(self):
        """Always raises: PromptedLLMs already batch their model calls."""
        raise ValueError("PromptedLLMs are autobatched by default.")

__init__(llm, prompt_ids=None, eos_tokens=None, temperature=1, token_maps=None)

Initializes the PromptedLLM potential.

Parameters:

Name Type Description Default
llm AsyncLM

The language model to use.

required
prompt_ids list[int]

Optional prompt to use as a prompt prefix for all input contexts. Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via set_prompt_from_str or prompt_ids.

None
eos_tokens list[bytes]

List of tokens to treat as end-of-sequence tokens. Defaults to the EOS token of the language model's tokenizer.

None
temperature float

The temperature to apply to the language model's logits. Defaults to 1.

1
token_maps TokenMappings

A precomputed mapping of tokens to token IDs with the potential's vocabulary. If provided, eos_tokens must not be provided. Defaults to None, which constructs a TokenMappings from the language model's byte vocabulary and the EOS tokens.

None
Source code in genlm/control/potential/built_in/llm.py
def __init__(
    self,
    llm,
    prompt_ids=None,
    eos_tokens=None,
    temperature=1,
    token_maps=None,
):
    """
    Initializes the PromptedLLM potential.

    Args:
        llm (AsyncLM): The language model to use.
        prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
            Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via
            `set_prompt_from_str` or `prompt_ids`.
        eos_tokens (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
            Defaults to the EOS token of the language model's tokenizer.
        temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
        token_maps (TokenMappings, optional): A precomputed mapping of tokens to token IDs with the potential's vocabulary.
            If provided, `eos_tokens` must not be provided. Defaults to None, which constructs a TokenMappings from the language model's byte vocabulary and the EOS tokens.

    Raises:
        ValueError: If both `token_maps` and `eos_tokens` are provided.
    """
    self.model = llm
    # Copy into a fresh list so the caller's sequence is never aliased, and so that
    # non-list sequences (e.g. tuples) still support `self.prompt_ids + context_ids`.
    self.prompt_ids = list(prompt_ids) if prompt_ids is not None else []
    self.temperature = temperature

    if token_maps is not None:
        if eos_tokens is not None:
            raise ValueError(
                "eos_tokens must not be provided when token_maps is provided."
            )
        self.token_maps = token_maps
    else:
        self.token_maps = TokenMappings.create(
            decode=self.model.byte_vocab,
            eos_tokens=eos_tokens
            or [self.model.byte_vocab[self.model.tokenizer.eos_token_id]],
        )

    super().__init__(vocabulary=self.token_maps.potential_vocab)

from_name(name, backend=None, eos_tokens=None, prompt_ids=None, temperature=1.0, **kwargs) classmethod

Create a PromptedLLM from a Hugging Face model name.

Parameters:

Name Type Description Default
name str

Name of the model to load

required
backend str

AsyncLM backend to use:

  • 'vllm' to instantiate an AsyncVirtualLM; ideal for GPU usage

  • 'hf' for an AsyncTransformer; ideal for CPU usage

  • 'mock' for a MockAsyncLM; ideal for testing.

Defaults to 'vllm' if CUDA is available, otherwise 'hf'.

None
eos_tokens list[bytes]

List of tokens to treat as end-of-sequence tokens. Defaults to the EOS token of the language model's tokenizer.

None
prompt_ids list[int]

Optional prompt to use as a prompt prefix for all input contexts. Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via set_prompt_from_str or prompt_ids.

None
temperature float

The temperature to apply to the language model's logits. Defaults to 1.

1.0
**kwargs dict

Additional arguments passed to AsyncLM constructor

{}

Returns:

Type Description
PromptedLLM

An instance of PromptedLLM

Source code in genlm/control/potential/built_in/llm.py
@classmethod
def from_name(
    cls,
    name,
    backend=None,
    eos_tokens=None,
    prompt_ids=None,
    temperature=1.0,
    **kwargs,
):
    """Build a `PromptedLLM` from a Hugging Face model name.

    Args:
        name (str): Name of the model to load.
        backend (str, optional): `AsyncLM` backend to use:\n
            * 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage\n
            * 'hf' for an `AsyncTransformer`; ideal for CPU usage\n
            * 'mock' for a `MockAsyncLM`; ideal for testing.\n
            When not given (or falsy), 'vllm' is used if CUDA is available, otherwise 'hf'.
        eos_tokens (list[bytes], optional): Tokens to treat as end-of-sequence tokens.
            Defaults to the EOS token of the language model's tokenizer.
        prompt_ids (list[int], optional): Token IDs used as a prompt prefix for all input contexts.
            Defaults to None. The prompt ids can be set post-init via `set_prompt_from_str` or `prompt_ids`.
        temperature (float, optional): Temperature applied to the language model's logits. Defaults to 1.
        **kwargs (dict): Additional arguments forwarded to the AsyncLM constructor.

    Returns:
        (PromptedLLM): The newly constructed potential.
    """
    if not backend:
        # Prefer the GPU-oriented backend whenever CUDA is present.
        backend = "vllm" if torch.cuda.is_available() else "hf"
    llm = load_model_by_name(name, backend=backend, **kwargs)
    return cls(
        llm,
        prompt_ids=prompt_ids,
        eos_tokens=eos_tokens,
        temperature=temperature,
    )

prompt property

Get the current prompt as a list of byte sequences corresponding to the prompt token IDs.

Returns:

Type Description
list[bytes] | None

The current prompt as a list of bytes sequences or None if no prompt_ids are set.

set_prompt_from_str(prompt_str)

Set the fixed prompt from a string.

Modifies prompt_ids to be the token IDs of the input prompt according to the language model's tokenizer.

Parameters:

Name Type Description Default
prompt_str str

The prompt to set.

required
Source code in genlm/control/potential/built_in/llm.py
def set_prompt_from_str(self, prompt_str):
    """Replace the fixed prompt with the tokenization of a string.

    The string is encoded with the language model's tokenizer and the resulting
    token IDs are stored in `prompt_ids`.

    Args:
        prompt_str (str): The prompt text.

    Raises:
        ValueError: If `prompt_str` is not a `str`.
    """
    # TODO: Handle race condition where prompt_ids reset concurrently.
    if isinstance(prompt_str, str):
        if prompt_str.endswith(" "):
            # Warn only; a trailing space may affect how the prompt is tokenized.
            warnings.warn(
                "Prompt ends with whitespace, which may affect tokenization. "
                "Consider removing trailing whitespace.",
                stacklevel=2,
            )
        self.prompt_ids = self.model.tokenizer.encode(prompt_str)
        return

    raise ValueError(
        f"Prompt must a string got {type(prompt_str)}. "
        f"To set the prompt from a list of token IDs, use prompt_ids."
    )

encode_tokens(tokens)

Encode a list of byte tokens to a list of token IDs in the underlying language model's vocabulary.

Parameters:

Name Type Description Default
tokens list[bytes]

List of byte tokens to encode

required

Returns:

Type Description
list[int]

A list of token IDs corresponding to the input tokens.

Raises:

Type Description
ValueError

If any token is not in the vocabulary

Source code in genlm/control/potential/built_in/llm.py
def encode_tokens(self, tokens):
    """Translate byte tokens into token IDs of the underlying language model.

    Args:
        tokens (list[bytes]): Byte tokens to translate.

    Returns:
        (list[int]): The token ID for each input token, in order.

    Raises:
        ValueError: If a token is missing from the vocabulary.
    """
    lookup = self.token_maps.encode
    try:
        ids = [lookup[tok] for tok in tokens]
    except KeyError as missing:
        raise ValueError(f"Token {missing.args[0]} not in vocabulary") from missing
    return ids

decode_tokens(ids)

Decode a list of token IDs in the language model's vocabulary to a list of byte tokens.

Parameters:

Name Type Description Default
ids list[int]

A list of token IDs in the language model's vocabulary.

required

Returns:

Type Description
list[bytes]

A list of byte tokens corresponding to the input token IDs.

Source code in genlm/control/potential/built_in/llm.py
def decode_tokens(self, ids):
    """
    Translate language-model token IDs back into byte tokens.

    Args:
        ids (list[int]): Token IDs in the language model's vocabulary.

    Returns:
        (list[bytes]): The byte token for each ID, in order.
    """
    vocab = self.token_maps.decode
    return list(map(vocab.__getitem__, ids))

tokenize(context_str)

Tokenize a string to a list of bytes objects, each corresponding to a token in the vocabulary.

Uses the language model's tokenizer to map context_str to a list of token IDs, and then decodes the token IDs to bytes.

Parameters:

Name Type Description Default
context_str str

A string to encode

required

Returns:

Type Description
List[bytes]

A list of byte tokens corresponding to the input string.

Source code in genlm/control/potential/built_in/llm.py
def tokenize(self, context_str):
    """Split a string into vocabulary `bytes` tokens.

    The string is first encoded into token IDs with the language model's tokenizer;
    those IDs are then decoded into their byte representations.

    Args:
        context_str (str): The text to tokenize.

    Returns:
        (List[bytes]): One byte token per token ID produced by the tokenizer.
    """
    token_ids = self.model.tokenizer.encode(context_str)
    return self.decode_tokens(token_ids)

log_probability(context) async

Compute the log probability of context given the prompt.

Parameters:

Name Type Description Default
context list[bytes]

A sequence of bytes tokens.

required

Returns:

Type Description
float

The log probability of context.

Source code in genlm/control/potential/built_in/llm.py
async def log_probability(self, context):
    """
    Log probability of `context` conditioned on the prompt.

    Args:
        context (list[bytes]): A sequence of byte tokens.

    Returns:
        (float): The log probability of `context`; 0 for an empty context.
    """
    if context:
        return await self._log_probability(self.encode_tokens(context))
    # An empty continuation has probability 1.
    return 0

prefix(context) async

Compute the log probability of context given the prompt.

Parameters:

Name Type Description Default
context list[bytes]

A sequence of bytes tokens.

required

Returns:

Type Description
float

The log probability of context.

Source code in genlm/control/potential/built_in/llm.py
async def prefix(self, context):
    """
    Prefix weight of `context`: its log probability given the prompt.

    Args:
        context (list[bytes]): A sequence of byte tokens.

    Returns:
        (float): The log probability of `context`.
    """
    logp = await self.log_probability(context)
    return logp

complete(context) async

Compute the log probability of context and the eos tokens given the prompt.

If the model has multiple eos tokens, their probabilities will be summed.

Parameters:

Name Type Description Default
context list[bytes]

A sequence of bytes tokens.

required

Returns:

Type Description
float

The log probability of the context.

Source code in genlm/control/potential/built_in/llm.py
async def complete(self, context):
    """
    Compute the log probability of `context` followed by an end-of-sequence token, given the prompt.

    If the model has multiple eos tokens, their probabilities will be summed.

    Args:
        context (list[bytes]): A sequence of bytes tokens.

    Returns:
        (float): The log probability of the context followed by EOS.
    """
    context_ids = self.encode_tokens(context)
    # An empty context has log probability 0 (matching `log_probability`). Skipping the
    # call also avoids sending an empty batch of prefixes to the model.
    logp_context = await self._log_probability(context_ids) if context_ids else 0.0
    logp_next = self._maybe_temper(
        await self.model.next_token_logprobs(self.prompt_ids + context_ids)
    )
    # P(EOS) is the sum of the probabilities of every configured EOS token.
    logp_eos = torch.logsumexp(logp_next[self.token_maps.eos_idxs], dim=0).item()
    return logp_context + logp_eos

logw_next(context) async

Get log probabilities for next tokens given the prompt and context.

Parameters:

Name Type Description Default
context List[bytes]

A sequence of bytes tokens.

required

Returns:

Type Description
LazyWeights

Log probabilities for next tokens and EOS.

Source code in genlm/control/potential/built_in/llm.py
async def logw_next(self, context):
    """Next-token log probabilities (including EOS) given the prompt and `context`.

    Args:
        context (List[bytes]): A sequence of byte tokens.

    Returns:
        (LazyWeights): Log probabilities over the vocabulary plus EOS.
    """
    full_ids = self.prompt_ids + self.encode_tokens(context)
    raw_logps = await self.model.next_token_logprobs(full_ids)
    tempered = self._maybe_temper(raw_logps)
    return self._process_logw_next(tempered)

batch_logw_next(contexts) async

Get log probabilities for next tokens given the prompt and context, for a batch of contexts.

Parameters:

Name Type Description Default
contexts list[list[bytes]]

A list of sequences of bytes tokens.

required

Returns:

Type Description
List[LazyWeights]

Log probabilities for next tokens and EOS for each context.

Source code in genlm/control/potential/built_in/llm.py
async def batch_logw_next(self, contexts):
    """Next-token log probabilities (including EOS) for each of several contexts.

    All contexts are prefixed with the prompt and scored in a single batched model call.

    Args:
        contexts (list[list[bytes]]): Sequences of byte tokens.

    Returns:
        (List[LazyWeights]): One set of next-token log probabilities per context.
    """
    batch_ids = [self.prompt_ids + self.encode_tokens(ctx) for ctx in contexts]
    raw_batch = await self.model.batch_next_token_logprobs(batch_ids)
    tempered_batch = self._maybe_temper(raw_batch)
    return list(map(self._process_logw_next, tempered_batch))

spawn(prompt_ids=None, eos_tokens=None, temperature=None)

Spawn a new PromptedLLM.

Parameters:

Name Type Description Default
prompt_ids (optional, list[int])

The prompt to use as a prompt prefix for all input contexts. Defaults to the same prompt_ids as self.

None
eos_tokens (optional, list[bytes])

A list of tokens to treat as end-of-sequence tokens. Defaults to the same eos_tokens as self.

None
temperature (optional, float)

The temperature with which to rescale logprobs. Defaults to the same temperature as self.

None

Returns:

Type Description
PromptedLLM

A new PromptedLLM sharing the same AsyncLM; its prompt, eos tokens, and temperature default to those of self unless overridden.

Note

This is a shallow copy. The new PromptedLLM will share the underlying AsyncLM instance.

Source code in genlm/control/potential/built_in/llm.py
def spawn(self, prompt_ids=None, eos_tokens=None, temperature=None):
    """
    Create a new PromptedLLM that shares this one's language model.

    Args:
        prompt_ids (optional, list[int]): Prompt prefix for the new potential.
            Defaults to a copy of `self.prompt_ids`.
        eos_tokens (optional, list[bytes]): End-of-sequence tokens for the new potential.
            Defaults to the eos tokens of `self`.
        temperature (optional, float): Temperature used to rescale logprobs.
            Defaults to `self.temperature`.

    Returns:
        (PromptedLLM): The new potential. Its prompt, eos tokens, and temperature
            default to those of `self` unless overridden.

    Note:
        This is a shallow copy. The new PromptedLLM will share the underlying AsyncLM instance.
    """
    new_prompt_ids = self.prompt_ids.copy() if prompt_ids is None else prompt_ids
    new_temperature = self.temperature if temperature is None else temperature

    same_eos = eos_tokens is None or eos_tokens == self.token_maps.eos_tokens
    if not same_eos:
        # New EOS tokens change the vocabulary, so token maps must be rebuilt.
        return PromptedLLM(
            self.model,
            prompt_ids=new_prompt_ids,
            eos_tokens=eos_tokens,
            temperature=new_temperature,
        )

    # Unchanged EOS tokens: reuse the existing token maps and vocabulary.
    return PromptedLLM(
        self.model,
        prompt_ids=new_prompt_ids,
        temperature=new_temperature,
        token_maps=self.token_maps,
    )

spawn_new_eos(eos_tokens)

Create a new PromptedLLM with a different set of end-of-sequence tokens.

Parameters:

Name Type Description Default
eos_tokens list[bytes]

A list of tokens to treat as end-of-sequence tokens.

required

Returns:

Type Description
PromptedLLM

A new PromptedLLM with the specified end-of-sequence tokens. The new model will have the same prompt_ids as self.

Source code in genlm/control/potential/built_in/llm.py
def spawn_new_eos(self, eos_tokens):
    """
    Spawn a PromptedLLM that uses a different set of end-of-sequence tokens.

    Args:
        eos_tokens (list[bytes]): Tokens to treat as end-of-sequence tokens.

    Returns:
        (PromptedLLM): A new potential with the given eos tokens and the same
            prompt_ids as `self`.
    """
    spawned = self.spawn(eos_tokens=eos_tokens)
    return spawned

BoolCFG

Bases: Potential

BoolCFG represents a boolean context-free grammar.

Source code in genlm/control/potential/built_in/wcfg.py
class BoolCFG(Potential):
    """BoolCFG represents a boolean context-free grammar.

    `complete` and `prefix` return log weight `0` when the grammar accepts the
    context (as a full string or as a prefix, respectively) and `-inf` otherwise.
    """

    def __init__(self, cfg):
        """
        Initialize the BoolCFG potential.

        Args:
            cfg (genlm_grammar.CFG): The grammar to wrap. If it is not already in the
                Boolean semiring, each weight `x` is mapped to `Boolean(x > 0)`.
        """
        if cfg.R != Boolean:
            cfg = cfg.map_values(lambda x: Boolean(x > 0), Boolean)
        self.cfg = cfg  # cfg before prefix transform
        self.cfg_eos = _add_eos(cfg, EOS)  # augmented with eos
        self.model = Earley(self.cfg_eos.prefix_grammar)
        super().__init__(vocabulary=list(cfg.V))

    @classmethod
    def from_lark(cls, lark_string, charset="core"):
        """
        Create a BoolCFG instance from a Lark grammar string.

        The output grammar will be defined at the byte-level.

        Args:
            lark_string (str): The Lark grammar string to parse. See Lark documentation for correct syntax.
            charset (str): The character set to use. Defaults to "core".
                See `genlm-grammar` documentation for more details.

        Returns:
            (BoolCFG): An instance of BoolCFG created from the provided Lark grammar.
        """
        byte_cfg = LarkStuff(lark_string).byte_cfg(charset=charset)
        return cls(byte_cfg)

    async def complete(self, context):
        """
        Checks whether the context is accepted by the CFG.

        Args:
            context (list): A sequence of tokens in the CFG's alphabet.

        Returns:
            (float): Log weight for whether `context` is accepted by the CFG.
        """
        # The parser runs on the EOS-augmented prefix grammar, so appending EOS
        # asks whether `context` is a complete string.
        w = self.model([*context, EOS])
        return 0 if w.score else float("-inf")

    async def prefix(self, context):
        """
        Checks whether `context` is accepted as a prefix by the CFG, i.e.,
        whether there exists a completion to `context` that is accepted by the CFG.

        Args:
            context (list): A sequence of tokens in the CFG's alphabet.

        Returns:
            (float): Log weight for whether `context` is accepted as a prefix by the CFG.
        """
        if not context:  # FIX: this is a hack to handle the empty string because genlm-grammar doesn't support it
            return 0
        w = self.model(context)
        return 0 if w.score else float("-inf")

    def _boolean_logw_next(self, context):
        """Synchronously compute Boolean next-token log weights for one context.

        Shared by `logw_next` and `batch_logw_next`. Each token in `vocab_eos`
        gets `0` if the parser reports a truthy weight for it, and `-inf` otherwise.
        """
        ws = self.model.next_token_weights(self.model.chart(context))
        log_ws = np.array([0 if ws[x].score else float("-inf") for x in self.vocab_eos])
        return self.make_lazy_weights(log_ws)

    async def logw_next(self, context):
        """
        Compute the next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the CFG's alphabet.

        Returns:
            (LazyWeights): The log weights for the next tokens and EOS given `context`.
        """
        return self._boolean_logw_next(context)

    async def batch_logw_next(self, contexts):
        """
        Batch version of `logw_next`.

        Args:
            contexts (list): A list of sequences of tokens in the CFG's alphabet.

        Returns:
            (list): A list of log-weights for next token, one per context.
        """
        return [self._boolean_logw_next(context) for context in contexts]

    def spawn(self):
        """Spawn a new instance of the same class over the same grammar.

        Uses `type(self)` so that subclasses spawn instances of themselves
        instead of plain `BoolCFG` objects.
        """
        return type(self)(self.cfg)

    def clear_cache(self):
        """Clear the internal cache of the parser."""
        self.model.clear_cache()

    def __repr__(self):
        return f"BoolCFG(cfg={self.cfg!r})"

    def _repr_html_(self):
        return self.cfg._repr_html_()

from_lark(lark_string, charset='core') classmethod

Create a BoolCFG instance from a Lark grammar string.

The output grammar will be defined at the byte-level.

Parameters:

Name Type Description Default
lark_string str

The Lark grammar string to parse. See Lark documentation for correct syntax.

required
charset str

The character set to use. Defaults to "core". See genlm-grammar documentation for more details.

'core'

Returns:

Type Description
BoolCFG

An instance of BoolCFG created from the provided Lark grammar.

Source code in genlm/control/potential/built_in/wcfg.py
@classmethod
def from_lark(cls, lark_string, charset="core"):
    """
    Build a byte-level BoolCFG from a Lark grammar definition.

    Args:
        lark_string (str): Grammar source in Lark syntax (see the Lark documentation).
        charset (str): Character set handed to `byte_cfg`; defaults to "core".
            Consult the `genlm-grammar` documentation for the available options.

    Returns:
        (BoolCFG): A new instance wrapping the byte-level grammar.
    """
    grammar = LarkStuff(lark_string)
    return cls(grammar.byte_cfg(charset=charset))

complete(context) async

Checks whether the context is accepted by the CFG.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the CFG's alphabet.

required

Returns:

Type Description
float

Log weight for whether context is accepted by the CFG.

Source code in genlm/control/potential/built_in/wcfg.py
async def complete(self, context):
    """
    Decide whether the grammar accepts `context` as a whole string.

    Args:
        context (list): Tokens drawn from the CFG's alphabet.

    Returns:
        (float): `0` when `context` is accepted, `-inf` otherwise.
    """
    # Appending EOS to the context asks the parser whether it is complete.
    accepted = self.model([*context, EOS]).score
    if accepted:
        return 0
    return float("-inf")

prefix(context) async

Checks whether context is accepted as a prefix by the CFG, i.e., whether there exists a completion to context that is accepted by the CFG.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the CFG's alphabet.

required

Returns:

Type Description
float

Log weight for whether context is accepted as a prefix by the CFG.

Source code in genlm/control/potential/built_in/wcfg.py
async def prefix(self, context):
    """
    Decide whether `context` can still be completed into a string the CFG accepts.

    Args:
        context (list): Tokens drawn from the CFG's alphabet.

    Returns:
        (float): `0` when some completion of `context` is accepted, `-inf` otherwise.
    """
    if context:
        accepted = self.model(context).score
        return 0 if accepted else float("-inf")
    # FIX: genlm-grammar cannot handle the empty string, so the empty prefix is
    # special-cased as always viable.
    return 0

logw_next(context) async

Compute the next token log weights given context.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the CFG's alphabet.

required

Returns:

Type Description
LazyWeights

The log weights for the next tokens and EOS given context.

Source code in genlm/control/potential/built_in/wcfg.py
async def logw_next(self, context):
    """
    Score every possible next token (and EOS) after `context`.

    Args:
        context (list): Tokens drawn from the CFG's alphabet.

    Returns:
        (LazyWeights): `0` for tokens the parser allows next, `-inf` for the rest,
            ordered as in `vocab_eos`.
    """
    chart = self.model.chart(context)
    next_ws = self.model.next_token_weights(chart)
    neg_inf = float("-inf")
    log_ws = np.array([0 if next_ws[tok].score else neg_inf for tok in self.vocab_eos])
    return self.make_lazy_weights(log_ws)

batch_logw_next(contexts) async

Batch version of logw_next.

Parameters:

Name Type Description Default
contexts list

A list of sequences of tokens in the CFG's alphabet.

required

Returns:

Type Description
list

A list of log-weights for next token, one per context.

Source code in genlm/control/potential/built_in/wcfg.py
async def batch_logw_next(self, contexts):
    """
    Apply the next-token scoring of `logw_next` to several contexts.

    Args:
        contexts (list): Token sequences drawn from the CFG's alphabet.

    Returns:
        (list): One LazyWeights per context, in the same order as `contexts`.
    """
    neg_inf = float("-inf")

    def _score_one(context):
        # 0 for tokens the parser allows next, -inf otherwise.
        next_ws = self.model.next_token_weights(self.model.chart(context))
        return self.make_lazy_weights(
            np.array([0 if next_ws[tok].score else neg_inf for tok in self.vocab_eos])
        )

    return [_score_one(context) for context in contexts]

spawn()

Spawn a new BoolCFG.

Source code in genlm/control/potential/built_in/wcfg.py
def spawn(self):
    """Spawn a new instance of the same class over the same grammar.

    Uses `type(self)` so that subclasses spawn instances of themselves
    instead of plain `BoolCFG` objects.
    """
    return type(self)(self.cfg)

clear_cache()

Clear the internal cache of the parser.

Source code in genlm/control/potential/built_in/wcfg.py
def clear_cache(self):
    """Discard whatever the underlying parser has memoized."""
    parser = self.model
    parser.clear_cache()

BoolFSA

Bases: WFSA

Boolean FSA potential.

Source code in genlm/control/potential/built_in/wfsa.py
class BoolFSA(WFSA):
    """Boolean FSA potential.

    Delegates scoring to `WFSA` and then throws away the magnitude of each log
    weight: values greater than `-inf` become `0`.
    """

    async def prefix(self, context):
        """
        Report whether some string accepted by the FSA starts with `context`.

        Args:
            context (list): Tokens drawn from the WFSA's alphabet.

        Returns:
            (float): `0` if the context is accepted as a prefix, `-inf` otherwise.
        """
        weight = await super().prefix(context)
        return 0 if weight > float("-inf") else float("-inf")

    async def complete(self, context):
        """
        Report whether the FSA accepts `context` as a whole string.

        Args:
            context (list): Tokens drawn from the WFSA's alphabet.

        Returns:
            (float): `0` if the context is accepted, `-inf` otherwise.
        """
        weight = await super().complete(context)
        return 0 if weight > float("-inf") else float("-inf")

    async def logw_next(self, context):
        """
        Boolean next-token scores after `context`.

        Args:
            context (list): Tokens drawn from the WFSA's alphabet.

        Returns:
            (LazyWeights): Weighted next-token scores with every entry greater than
                `-inf` replaced by `0`.
        """
        lazy = await super().logw_next(context)
        weights = lazy.weights
        return lazy.spawn(new_weights=np.where(weights > float("-inf"), 0, weights))

    async def batch_logw_next(self, contexts):
        """
        Boolean next-token scores for several contexts.

        Args:
            contexts (list): Token sequences drawn from the WFSA's alphabet.

        Returns:
            (list): One Boolean LazyWeights per context, in input order.
        """
        results = []
        for lazy in await super().batch_logw_next(contexts):
            weights = lazy.weights
            results.append(
                lazy.spawn(new_weights=np.where(weights > float("-inf"), 0, weights))
            )
        return results

    def __repr__(self):
        return f"BoolFSA(wfsa={self.wfsa!r})"

prefix(context) async

Computes whether the context is accepted as a prefix by the FSA.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WFSA's alphabet.

required

Returns:

Type Description
float

0 if the context is accepted as a prefix, -inf otherwise.

Source code in genlm/control/potential/built_in/wfsa.py
async def prefix(self, context):
    """
    Report whether some string accepted by the FSA starts with `context`.

    Args:
        context (list): Tokens drawn from the WFSA's alphabet.

    Returns:
        (float): `0` if the context is accepted as a prefix, `-inf` otherwise.
    """
    weight = await super().prefix(context)
    return 0 if weight > float("-inf") else float("-inf")

complete(context) async

Computes whether the context is accepted by the FSA.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WFSA's alphabet.

required

Returns:

Type Description
float

0 if the context is accepted, -inf otherwise.

Source code in genlm/control/potential/built_in/wfsa.py
async def complete(self, context):
    """
    Report whether the FSA accepts `context` as a whole string.

    Args:
        context (list): Tokens drawn from the WFSA's alphabet.

    Returns:
        (float): `0` if the context is accepted, `-inf` otherwise.
    """
    weight = await super().complete(context)
    return 0 if weight > float("-inf") else float("-inf")

logw_next(context) async

Returns next token log weights given context.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WFSA's alphabet.

required

Returns:

Type Description
LazyWeights

Boolean log-weights for next token.

Source code in genlm/control/potential/built_in/wfsa.py
async def logw_next(self, context):
    """
    Boolean next-token scores after `context`.

    Args:
        context (list): Tokens drawn from the WFSA's alphabet.

    Returns:
        (LazyWeights): Weighted next-token scores with every entry greater than
            `-inf` replaced by `0`.
    """
    lazy = await super().logw_next(context)
    weights = lazy.weights
    return lazy.spawn(new_weights=np.where(weights > float("-inf"), 0, weights))

batch_logw_next(contexts) async

Returns next token log weights for a batch of contexts.

Parameters:

Name Type Description Default
contexts list

The list of contexts.

required

Returns:

Type Description
list

List of log-weights for next token, one per context.

Source code in genlm/control/potential/built_in/wfsa.py
async def batch_logw_next(self, contexts):
    """
    Boolean next-token scores for several contexts.

    Args:
        contexts (list): Token sequences drawn from the WFSA's alphabet.

    Returns:
        (list): One Boolean LazyWeights per context, in input order.
    """
    results = []
    for lazy in await super().batch_logw_next(contexts):
        weights = lazy.weights
        results.append(
            lazy.spawn(new_weights=np.where(weights > float("-inf"), 0, weights))
        )
    return results

WFSA

Bases: Potential

A weighted finite state automaton (WFSA) potential.

This class wraps a genlm_grammar.WFSA and provides methods for computing the log-weight of a context, the prefix log-weight of a context, and the log-weights of the next token given a context.

Attributes:

Name Type Description
wfsa WFSA

The weighted finite state automaton used for potential calculations.

Source code in genlm/control/potential/built_in/wfsa.py
class WFSA(Potential):
    """
    A weighted finite state automaton (WFSA) potential.

    This class wraps a `genlm_grammar.WFSA` and provides methods for computing the log-weight of a context,
    the prefix log-weight of a context, and the log-weights of the next token given a context.

    Attributes:
        wfsa (genlm_grammar.WFSA): The weighted finite state automaton used for potential calculations.
            Stored in the Log semiring.
        cache (dict): Maps consumed context tuples to the chart of state weights
            reached in the epsilon-removed WFSA after reading that context.
    """

    def __init__(self, wfsa):
        """
        Initializes the WFSA potential.

        Args:
            wfsa (genlm_grammar.WFSA): The weighted finite state automaton.

        Raises:
            ValueError: If the semiring of the provided WFSA is not Float or Log.

        Note:
            The WFSA will be converted to the Log semiring to avoid underflow if the semiring is Float.
        """
        if wfsa.R not in (Float, Log):
            raise ValueError(f"Unsupported semiring: {wfsa.R}")

        if wfsa.R is Float:
            self.wfsa = self._convert_to_log(wfsa)
        else:
            self.wfsa = wfsa

        # The empty context is seeded with the start-state chart; `_consume`
        # extends from the longest cached prefix.
        self.cache = {(): self.wfsa.epsremove.start}
        super().__init__(vocabulary=list(self.wfsa.alphabet))

    @classmethod
    def from_regex(cls, pattern, charset=None, to_bytes=True):
        """
        Create a WFSA from a regex pattern.

        Args:
            pattern (str): The regex pattern to convert into a WFSA.
            charset (set): The character set to use for negative character classes.
                If None (the default), the characters in string.printable are used.
            to_bytes (bool): Whether to convert the WFSA transitions to bytes.
                Defaults to True. When set to False, the WFSA transitions will be strings.

        Returns:
            (WFSA): An instance of the WFSA class.

        Note:
            The transition weights are automatically normalized to form a probability distribution.
            For each state, the weights of all outgoing transitions (including final state transitions)
            sum to 1.0. This means if a state has n possible transitions, each transition will have
            weight 1/n. To create a WFSA from a regex with non-probabilistic transitions, use `BoolFSA`.
        """
        # Only fall back to the default when no charset was given; an explicitly
        # empty charset is respected rather than silently replaced.
        if charset is None:
            charset = set(string.printable)
        wfsa = interegular_to_wfsa(pattern, charset=charset)
        if to_bytes:
            wfsa = wfsa.to_bytes()
        return cls(wfsa=wfsa)

    @staticmethod
    def _convert_to_log(wfsa):
        """Convert a WFSA from the Float semiring to the Log semiring."""
        assert wfsa.R is Float
        assert isinstance(wfsa, BaseWFSA)
        new = BaseWFSA(Log)

        for i, w in wfsa.I:
            new.add_I(i, Log(np.log(w)))

        for i, w in wfsa.F:
            new.add_F(i, Log(np.log(w)))

        for i, a, j, w in wfsa.arcs():
            new.add_arc(i, a, j, Log(np.log(w)))

        return new

    def _consume(self, bs):
        """Return the chart of state weights reached after reading `bs`.

        Charts for every prefix of `bs` are memoized in `self.cache`. The chart is
        built iteratively, starting from the longest prefix already in the cache,
        so long uncached contexts do not exhaust Python's recursion limit.
        """
        # XXX implement cache eviction
        bs = tuple(bs)

        try:
            return self.cache[bs]
        except KeyError:
            pass

        # Walk back to the longest cached proper prefix. The empty prefix is
        # seeded in `__init__` / `clear_cache`.
        k = len(bs) - 1
        while k > 0 and bs[:k] not in self.cache:
            k -= 1
        curr = self.cache[bs[:k]]

        wfsa = self.wfsa.epsremove
        for n in range(k, len(bs)):
            prev, curr = curr, wfsa.R.chart()
            for i in prev:
                for j, w in wfsa.arcs(i, bs[n]):
                    curr[j] += prev[i] * w
            self.cache[bs[: n + 1]] = curr

        return curr

    async def complete(self, context):
        """
        Computes the log weight of the context under the weighted language represented by the WFSA.

        For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WFSA\n
        - `complete("cat")` returns $\\log(w_{cat})$\n
        - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WFSA

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (float): Log weight of context under the WFSA.
        """
        # TODO: optimize to use _consume cache
        return self.wfsa(context).score

    def _prefix(self, context):
        """Return `(prefix log weight, state chart)` for `context`."""
        curr = self._consume(context)

        if not curr:
            return float("-inf"), curr

        # Sum over reachable states of forward weight times backward weight.
        bkwd = self.wfsa.epsremove.backward
        log_ctx_w = logsumexp([(curr[i] * bkwd[i]).score for i in curr])

        if np.isnan(log_ctx_w):
            return float("-inf"), curr

        return log_ctx_w, curr

    async def prefix(self, context):
        """
        Computes the prefix log weight of `context` under the WFSA.

        This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

        For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
        - `prefix("cat")` returns $\\log(w_{cat})$\n
        - `prefix("d")` returns $-\\infty$ since the WFSA does not accept any sequences with prefix "d"

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (float): Log weight of `context` as a prefix under the WFSA.
        """
        return self._prefix(context)[0]

    async def logw_next(self, context):
        """Returns next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (LazyWeights): Log-weights for next token and EOS.

        Raises:
            ValueError: If `context` has zero prefix weight.
        """
        log_ctx_w, curr = self._prefix(context)

        if log_ctx_w == float("-inf"):
            raise ValueError(f"Context {context!r} has zero weight.")

        bkwd = self.wfsa.epsremove.backward

        ws = self.wfsa.R.chart()
        for i in curr:
            for b, j, w in self.wfsa.epsremove.arcs(i=i):
                ws[b] += curr[i] * w * bkwd[j]

        # EOS weight: stopping now in a final state.
        ws[self.eos] = self.wfsa.R.zero
        for j, w in self.wfsa.epsremove.F:
            ws[self.eos] += curr[j] * w

        # Divide (in log space) by the context's prefix weight.
        log_ws = np.array([ws[b].score for b in self.vocab_eos]) - log_ctx_w

        return self.make_lazy_weights(log_ws)

    def _repr_svg_(self):
        return self.wfsa._repr_svg_()

    def __repr__(self):
        return f"WFSA(wfsa={self.wfsa!r})"

    def spawn(self):
        cls = type(self)
        return cls(wfsa=self.wfsa)

    def clear_cache(self):
        self.cache = {(): self.wfsa.epsremove.start}

__init__(wfsa)

Initializes the WFSA potential.

Parameters:

Name Type Description Default
wfsa WFSA

The weighted finite state automaton.

required

Raises:

Type Description
ValueError

If the semiring of the provided WFSA is not Float or Log.

Note

The WFSA will be converted to the Log semiring to avoid underflow if the semiring is Float.

Source code in genlm/control/potential/built_in/wfsa.py
def __init__(self, wfsa):
    """
    Set up the WFSA potential around a `genlm_grammar.WFSA`.

    Args:
        wfsa (genlm_grammar.WFSA): Automaton to wrap; must use the Float or Log semiring.

    Raises:
        ValueError: If the automaton's semiring is neither Float nor Log.

    Note:
        Float-semiring automata are converted to the Log semiring first, to avoid underflow.
    """
    semiring = wfsa.R
    if semiring not in (Float, Log):
        raise ValueError(f"Unsupported semiring: {wfsa.R}")

    self.wfsa = self._convert_to_log(wfsa) if semiring is Float else wfsa

    # Seed the prefix cache: the empty context maps to the start-state chart.
    self.cache = {(): self.wfsa.epsremove.start}
    super().__init__(vocabulary=list(self.wfsa.alphabet))

from_regex(pattern, charset=None, to_bytes=True) classmethod

Create a WFSA from a regex pattern.

Parameters:

Name Type Description Default
pattern str

The regex pattern to convert into a WFSA.

required
charset set

The character set to use for negative character classes. Defaults to characters in string.printable.

None
to_bytes bool

Whether to convert the WFSA transitions to bytes. Defaults to True. When set to False, the WFSA transitions will be strings.

True

Returns:

Type Description
WFSA

An instance of the WFSA class.

Note

The transition weights are automatically normalized to form a probability distribution. For each state, the weights of all outgoing transitions (including final state transitions) sum to 1.0. This means if a state has n possible transitions, each transition will have weight 1/n. To create a WFSA from a regex with non-probabilistic transitions, use BoolFSA.

Source code in genlm/control/potential/built_in/wfsa.py
@classmethod
def from_regex(cls, pattern, charset=None, to_bytes=True):
    """
    Create a WFSA from a regex pattern.

    Args:
        pattern (str): The regex pattern to convert into a WFSA.
        charset (set): The character set to use for negative character classes.
            If None (the default), the characters in string.printable are used.
        to_bytes (bool): Whether to convert the WFSA transitions to bytes.
            Defaults to True. When set to False, the WFSA transitions will be strings.

    Returns:
        (WFSA): An instance of the WFSA class.

    Note:
        The transition weights are automatically normalized to form a probability distribution.
        For each state, the weights of all outgoing transitions (including final state transitions)
        sum to 1.0. This means if a state has n possible transitions, each transition will have
        weight 1/n. To create a WFSA from a regex with non-probabilistic transitions, use `BoolFSA`.
    """
    # Only fall back to the default when no charset was given; an explicitly
    # empty charset is respected rather than silently replaced.
    if charset is None:
        charset = set(string.printable)
    wfsa = interegular_to_wfsa(pattern, charset=charset)
    if to_bytes:
        wfsa = wfsa.to_bytes()
    return cls(wfsa=wfsa)

complete(context) async

Computes the log weight of the context under the weighted language represented by the WFSA.

For example, if the WFSA accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • complete("c") returns \(-\infty\) since this sequence is not accepted by the WFSA

  • complete("cat") returns \(\log(w_{cat})\)

  • complete("d") returns \(-\infty\) since this sequence is not accepted by the WFSA

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WFSA's alphabet.

required

Returns:

Type Description
float

Log weight of context under the WFSA.

Source code in genlm/control/potential/built_in/wfsa.py
async def complete(self, context):
    """
    Log weight that the WFSA assigns to `context` as a finished string.

    For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WFSA\n
    - `complete("cat")` returns $\\log(w_{cat})$\n
    - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WFSA

    Args:
        context (list): Tokens drawn from the WFSA's alphabet.

    Returns:
        (float): The score of the wrapped WFSA evaluated on `context`.
    """
    # TODO: optimize to use _consume cache
    weight = self.wfsa(context)
    return weight.score

prefix(context) async

Computes the prefix log weight of context under the WFSA.

This corresponds to the log of the sum of the weights of all sequences with prefix context.

For example, if the WFSA accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • prefix("c") returns \(\log(w_{cat} + w_{car})\)

  • prefix("cat") returns \(\log(w_{cat})\)

  • prefix("d") returns \(-\infty\) since the WFSA does not accept any sequences with prefix "d"

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WFSA's alphabet.

required

Returns:

Type Description
float

Log weight of context as a prefix under the WFSA.

Source code in genlm/control/potential/built_in/wfsa.py
async def prefix(self, context):
    """
    Computes the prefix log weight of `context` under the WFSA.

    This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

    For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
    - `prefix("cat")` returns $\\log(w_{cat})$\n
    - `prefix("d")` returns $-\\infty$ since the WFSA does not accept any sequences with prefix "d"

    Args:
        context (list): A sequence of tokens in the WFSA's alphabet.

    Returns:
        (float): Log weight of `context` as a prefix under the WFSA.
    """
    # `_prefix` returns (log weight, state chart); only the weight is needed here.
    return self._prefix(context)[0]

logw_next(context) async

Returns next token log weights given context.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WFSA's alphabet.

required

Returns:

Type Description
LazyWeights

Log-weights for next token and EOS.

Source code in genlm/control/potential/built_in/wfsa.py
async def logw_next(self, context):
    """Next-token (and EOS) log weights conditioned on `context`.

    Args:
        context (list): Tokens drawn from the WFSA's alphabet.

    Returns:
        (LazyWeights): For each entry of `vocab_eos`, its weight minus the
            context's prefix log weight.

    Raises:
        ValueError: If `context` has zero prefix weight.
    """
    log_ctx_w, curr = self._prefix(context)
    if log_ctx_w == float("-inf"):
        raise ValueError(f"Context {context!r} has zero weight.")

    epsfree = self.wfsa.epsremove
    semiring = self.wfsa.R
    bkwd = epsfree.backward

    # Weight of each symbol: forward weight * arc weight * backward weight of target.
    ws = semiring.chart()
    for i in curr:
        for b, j, w in epsfree.arcs(i=i):
            ws[b] += curr[i] * w * bkwd[j]

    # EOS weight: stopping now in a final state.
    ws[self.eos] = semiring.zero
    for j, w in epsfree.F:
        ws[self.eos] += curr[j] * w

    scores = np.array([ws[b].score for b in self.vocab_eos])
    return self.make_lazy_weights(scores - log_ctx_w)

WCFG

Bases: Potential

A weighted context-free grammar potential.

This class wraps a genlm_grammar.CFG and provides methods for computing the log-weight of a sequence, the prefix log-weight of a sequence, and the log-weights of the next token given a sequence.

Source code in genlm/control/potential/built_in/wcfg.py
class WCFG(Potential):
    """
    A weighted context-free grammar potential.

    This class wraps a `genlm_grammar.CFG` and provides methods for computing the log-weight of a sequence,
    the prefix log-weight of a sequence, and the log-weights of the next token given a sequence.
    """

    def __init__(self, cfg):
        """
        Initialize the WCFG potential.

        Args:
            cfg (genlm_grammar.CFG): The context-free grammar configuration to use.
                The CFG must be in the Float semiring.

        Raises:
            ValueError: If `cfg` is not in the Float semiring.
        """
        # TODO: convert to LogSemiring to handle underflow
        if cfg.R is not Float:
            raise ValueError("cfg semiring must be Float")
        self.cfg = cfg  # cfg before prefix transform
        self.cfg_eos = _add_eos(cfg, EOS)  # augmented with eos
        self.model = Earley(self.cfg_eos.prefix_grammar)
        super().__init__(vocabulary=list(cfg.V))

    @classmethod
    def from_string(cls, grammar, to_bytes=True, **kwargs):
        """Create a WCFG from a string.

        Args:
            grammar (str): The string grammar specification to create the WCFG from.
            to_bytes (bool, optional): Whether to convert the WCFG terminals to individual bytes.
                Defaults to True.
            **kwargs (dict): Additional arguments passed to the WCFG constructor.

        Returns:
            (WCFG): The created WCFG.
        """
        cfg = CFG.from_string(grammar, Float)
        if to_bytes:
            cfg = cfg.to_bytes()
        return cls(cfg, **kwargs)

    async def complete(self, context):
        """
        Compute the log weight of `context` under the WCFG.

        For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WCFG\n
        - `complete("cat")` returns $\\log(w_{cat})$\n
        - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WCFG

        Args:
            context (list): A sequence of tokens in the WCFG's alphabet.

        Returns:
            (float): The log weight of `context` under the WCFG.
        """
        # Appending EOS to the context asks the prefix-grammar parser for the complete-string weight.
        w = self.model([*context, EOS])
        return np.log(w) if w > 0 else float("-inf")

    async def prefix(self, context):
        """
        Compute the log prefix weight of `context` under the WCFG.

        This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

        For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
        - `prefix("cat")` returns $\\log(w_{cat})$\n
        - `prefix("d")` returns $-\\infty$ since the WCFG does not accept any sequences with prefix "d"

        Args:
            context (list): A sequence of tokens in the WCFG's alphabet.

        Returns:
            (float): The log prefix weight of `context` under the WCFG.
        """
        w = self.model(context)
        return np.log(w) if w > 0 else float("-inf")

    async def logw_next(self, context):
        """
        Compute the next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the WCFG's alphabet.

        Returns:
            (LazyWeights): The log weights for the next tokens and EOS given `context`.
        """
        ws = self.model.next_token_weights(self.model.chart(context))
        ws = ws.trim().normalize()

        ws_array = np.array([ws[x] for x in self.vocab_eos])
        # Take logs only of positive entries so zero weights become -inf
        # without triggering log(0) warnings.
        mask = ws_array > 0
        log_ws = np.full_like(ws_array, float("-inf"), dtype=np.float64)
        log_ws[mask] = np.log(ws_array[mask])

        return self.make_lazy_weights(log_ws)

    def clear_cache(self):
        """Clear the internal cache of the parser."""
        self.model.clear_cache()

    def __repr__(self):
        return f"WCFG(cfg={self.cfg!r})"

    def _repr_html_(self):
        return self.cfg._repr_html_()

    def spawn(self):
        """Spawn a new instance of the same class over the same grammar.

        Uses `type(self)` so that subclasses spawn instances of themselves
        instead of plain `WCFG` objects.
        """
        return type(self)(self.cfg)

__init__(cfg)

Initialize the WCFG potential.

Parameters:

Name Type Description Default
cfg CFG

The context-free grammar configuration to use. The CFG must be in the Float semiring.

required
Source code in genlm/control/potential/built_in/wcfg.py
def __init__(self, cfg):
    """
    Initialize the WCFG potential.

    Args:
        cfg (genlm_grammar.CFG): The context-free grammar configuration to use.
            The CFG must be in the Float semiring.

    Raises:
        ValueError: If `cfg` is not in the Float semiring.
    """
    # TODO: convert to LogSemiring to handle underflow
    if cfg.R is not Float:
        raise ValueError("cfg semiring must be Float")
    self.cfg = cfg  # cfg before prefix transform
    # Grammar augmented with an explicit EOS terminal; complete sequences are
    # scored by parsing `[*context, EOS]`.
    self.cfg_eos = _add_eos(cfg, EOS)
    # Parse against the prefix grammar so that evaluating the parser on a
    # context yields its prefix weight (see `prefix`).
    self.model = Earley(self.cfg_eos.prefix_grammar)
    # The vocabulary is the original grammar's terminals; EOS is handled by the base class.
    super().__init__(vocabulary=list(cfg.V))

from_string(grammar, to_bytes=True, **kwargs) classmethod

Create a WCFG from a string.

Parameters:

Name Type Description Default
grammar str

The string grammar specification to create the WCFG from.

required
to_bytes bool

Whether to convert the WCFG terminals to individual bytes. Defaults to True.

True
**kwargs dict

Additional arguments passed to the WCFG constructor.

{}

Returns:

Type Description
WCFG

The created WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
@classmethod
def from_string(cls, grammar, to_bytes=True, **kwargs):
    """Create a WCFG from a string.

    Args:
        grammar (str): The string grammar specification to create the WCFG from.
        to_bytes (bool, optional): Whether to convert the WCFG terminals to individual bytes.
            Defaults to True.
        **kwargs (dict): Additional arguments passed to the WCFG constructor.

    Returns:
        (WCFG): The created WCFG.
    """
    # Parse in the Float semiring, which the constructor requires.
    cfg = CFG.from_string(grammar, Float)
    if to_bytes:
        cfg = cfg.to_bytes()
    return cls(cfg, **kwargs)

complete(context) async

Compute the log weight of context under the WCFG.

For example, if the WCFG accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • complete("c") returns \(-\infty\) since this sequence is not accepted by the WCFG

  • complete("cat") returns \(\log(w_{cat})\)

  • complete("d") returns \(-\infty\) since this sequence is not accepted by the WCFG

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WCFG's alphabet.

required

Returns:

Type Description
float

The log weight of context under the WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
async def complete(self, context):
    """
    Compute the log weight of `context` under the WCFG.

    The context is terminated with EOS before parsing, so only sequences the
    grammar accepts as complete strings receive a finite weight.

    For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WCFG\n
    - `complete("cat")` returns $\\log(w_{cat})$\n
    - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WCFG

    Args:
        context (list): A sequence of tokens in the WCFG's alphabet.

    Returns:
        (float): The log weight of `context` under the WCFG (`-inf` when the weight is not positive).
    """
    terminated = [*context, EOS]
    weight = self.model(terminated)
    if weight > 0:
        return np.log(weight)
    return float("-inf")

prefix(context) async

Compute the log prefix weight of context under the WCFG.

This corresponds to the log of the sum of the weights of all sequences with prefix context.

For example, if the WCFG accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • prefix("c") returns \(\log(w_{cat} + w_{car})\)

  • prefix("cat") returns \(\log(w_{cat})\)

  • prefix("d") returns \(-\infty\) since the WCFG does not accept any sequences with prefix "d"

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WCFG's alphabet.

required

Returns:

Type Description
float

The log prefix weight of context under the WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
async def prefix(self, context):
    """
    Compute the log prefix weight of `context` under the WCFG.

    This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

    For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
    - `prefix("cat")` returns $\\log(w_{cat})$\n
    - `prefix("d")` returns $-\\infty$ since the WCFG does not accept any sequences with prefix "d"

    Args:
        context (list): A sequence of tokens in the WCFG's alphabet.

    Returns:
        (float): The log prefix weight of `context` under the WCFG (`-inf` when the weight is not positive).
    """
    weight = self.model(context)
    if weight > 0:
        return np.log(weight)
    return float("-inf")

logw_next(context) async

Compute the next token log weights given context.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the WCFG's alphabet.

required

Returns:

Type Description
LazyWeights

The log weights for the next tokens and EOS given context.

Source code in genlm/control/potential/built_in/wcfg.py
async def logw_next(self, context):
    """
    Compute the next token log weights given `context`.

    The parser's next-token weights for `context` are trimmed and normalized,
    gathered in `vocab_eos` order, and converted to log space. Tokens whose
    weight is not strictly positive receive `-inf`.

    Args:
        context (list): A sequence of tokens in the WCFG's alphabet.

    Returns:
        (LazyWeights): The log weights for the next tokens and EOS given `context`.
    """
    chart = self.model.chart(context)
    next_ws = self.model.next_token_weights(chart).trim().normalize()

    # Gather weights in vocab_eos order so indices line up with the LazyWeights.
    probs = np.array([next_ws[tok] for tok in self.vocab_eos])
    logws = np.full_like(probs, float("-inf"), dtype=np.float64)
    positive = probs > 0
    # Only take the log of strictly positive entries to avoid log(0) warnings.
    logws[positive] = np.log(probs[positive])

    return self.make_lazy_weights(logws)

clear_cache()

Clear the internal cache of the parser.

Source code in genlm/control/potential/built_in/wcfg.py
def clear_cache(self):
    """Reset the parser's internal cache.

    Delegates to the ``clear_cache`` method of the underlying parser (``self.model``).
    """
    parser = self.model
    parser.clear_cache()

spawn()

Spawn a new WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
def spawn(self):
    """Spawn a new WCFG.

    The new instance is built from the original grammar (``self.cfg``, i.e. the
    grammar before EOS augmentation), so the constructor creates a fresh parser.
    """
    return WCFG(cfg=self.cfg)

CanonicalTokenization

Bases: Potential

A custom potential that enforces canonical BPE tokenization.

This potential ensures that tokens follow the canonical tokenization rules by using the FastCanonicalityFilterBPE under the hood.

Source code in genlm/control/potential/built_in/canonical.py
class CanonicalTokenization(Potential):
    """
    A custom potential that enforces canonical BPE tokenization.

    This potential ensures that tokens follow the canonical tokenization rules
    by using the FastCanonicalityFilterBPE under the hood.

    Canonicality is checked pairwise: a sequence is canonical when, for every
    pair of adjacent tokens, the filter's mask for the previous token allows
    the current token. EOS is always allowed as the next token.
    """

    def __init__(self, canonicality_filter):
        """
        Initialize the Canonical Potential

        Args:
            canonicality_filter (FastCanonicalityFilterBPE): An initialized FastCanonicalityFilterBPE instance.
        """
        # Store the pre-initialized filter; it is the only state this potential keeps.
        self.canonicality_filter = canonicality_filter

        # IMPORTANT: In the base Potential class, EOS will be added to vocab automatically
        # So we should NOT add it ourselves to the vocabulary we pass to super().__init__
        vocabulary = self.canonicality_filter._decode
        super().__init__(vocabulary)

    @classmethod
    def from_llm(cls, llm):
        """
        Factory method to create CanonicalTokenization from a PromptedLLM instance.

        Args:
            llm (PromptedLLM): An instance of PromptedLLM containing the model and tokenizer.

        Returns:
            (CanonicalTokenization): An initialized CanonicalTokenization instance.

        Raises:
            TypeError: If `llm` is not a PromptedLLM.
        """
        if not isinstance(llm, PromptedLLM):
            raise TypeError(
                f"Expected llm to be an instance of PromptedLLM, got {type(llm)}"
            )

        # Extract necessary components from llm
        tokenizer = llm.model.tokenizer
        eos_token_ids = llm.token_maps.eos_idxs
        model_name = tokenizer.name_or_path

        # Create the filter using its factory method
        canonicality_filter = FastCanonicalityFilterBPE.from_tokenizer(
            tokenizer, eos_token_ids
        )

        # Set overrides on the filter, based on the tokenizer's `name_or_path`
        canonicality_filter.set_overrides(model_name)

        # Only the filter is passed on; the tokenizer is not stored on the potential.
        return cls(canonicality_filter)

    async def complete(self, context):
        """
        Assess if a complete sequence follows canonical tokenization.

        Args:
            context (list): Sequence of tokens

        Returns:
            (float): 0.0 if canonical, float('-inf') otherwise
        """
        # Empty sequences are considered canonical
        if not context:
            return 0.0

        # Check if the sequence is canonical
        is_canonical = self._check_canonicality(context)
        return 0.0 if is_canonical else float("-inf")

    async def prefix(self, context):
        """
        Assess if a prefix sequence could potentially extend to a canonical sequence.
        For canonicality, this is the same as complete.

        Args:
            context (list): Sequence of tokens

        Returns:
            (float): 0.0 if potentially canonical, float('-inf') otherwise
        """
        return await self.complete(context)

    async def logw_next(self, context):
        """
        Compute weights for each possible next token given the context.

        Args:
            context (list): Sequence of tokens

        Returns:
            (LazyWeights): Weights for each token in the vocabulary and EOS

        Raises:
            ValueError: If `context` itself is not canonical.
        """
        # Refuse to extend a context that is already non-canonical.
        if await self.prefix(context) == float("-inf"):
            raise ValueError("Context is non-canonical")

        if context:
            # The filter is queried with (None, previous_token) and returns a
            # mask indexed by token id over the filter's vocabulary (`_decode`).
            filter_mask = self.canonicality_filter((None, context[-1]))
        else:
            # With no previous token, every token is allowed first.
            filter_mask = np.ones(len(self.canonicality_filter._decode), dtype=bool)

        # Create log weights directly instead of using np.log(filter_mask).
        # This is more efficient and avoids torch (with torch the result can't be
        # combined with other potentials).
        logws_no_eos = np.where(filter_mask, 0.0, float("-inf")).astype(np.float32)

        # `_decode` does not include EOS while `vocab_eos` ends with it, so append
        # an always-allowed EOS entry.
        logws = np.concatenate([logws_no_eos, np.array([0.0], dtype=np.float32)])

        return self.make_lazy_weights(logws)

    def _check_canonicality(self, context):
        """
        Check if a sequence follows canonical tokenization.

        Args:
            context (list): Sequence of tokens

        Returns:
            (bool): True if the sequence is canonical, False otherwise
        """
        # Sequences with fewer than two tokens have no adjacent pairs, so the
        # loop body never runs and they are canonical.
        for prev_token, current_token in zip(context, context[1:]):
            # Format expected by the filter: (None, previous_token)
            mask = self.canonicality_filter((None, prev_token))

            # Find token_id in the canonicality filter's vocabulary
            token_id = self.canonicality_filter._encode[current_token]
            if not mask[token_id]:
                return False

        return True

__init__(canonicality_filter)

Initialize the Canonical Potential

Parameters:

Name Type Description Default
canonicality_filter FastCanonicalityFilterBPE

An initialized FastCanonicalityFilterBPE instance.

required
Source code in genlm/control/potential/built_in/canonical.py
def __init__(self, canonicality_filter):
    """
    Initialize the Canonical Potential

    Args:
        canonicality_filter (FastCanonicalityFilterBPE): An initialized FastCanonicalityFilterBPE instance.
    """
    # Store the pre-initialized filter; it is the only state this potential keeps.
    self.canonicality_filter = canonicality_filter

    # IMPORTANT: In the base Potential class, EOS will be added to vocab automatically
    # So we should NOT add it ourselves to the vocabulary we pass to super().__init__
    vocabulary = self.canonicality_filter._decode
    super().__init__(vocabulary)

from_llm(llm) classmethod

Factory method to create CanonicalTokenization from a PromptedLLM instance.

Parameters:

Name Type Description Default
llm PromptedLLM

An instance of PromptedLLM containing the model and tokenizer.

required

Returns:

Type Description
CanonicalTokenization

An initialized CanonicalTokenization instance.

Source code in genlm/control/potential/built_in/canonical.py
@classmethod
def from_llm(cls, llm):
    """
    Factory method to create CanonicalTokenization from a PromptedLLM instance.

    Args:
        llm (PromptedLLM): An instance of PromptedLLM containing the model and tokenizer.

    Returns:
        (CanonicalTokenization): An initialized CanonicalTokenization instance.

    Raises:
        TypeError: If `llm` is not a PromptedLLM.
    """
    if not isinstance(llm, PromptedLLM):
        raise TypeError(
            f"Expected llm to be an instance of PromptedLLM, got {type(llm)}"
        )

    # Extract necessary components from llm
    tokenizer = llm.model.tokenizer
    eos_token_ids = llm.token_maps.eos_idxs
    model_name = tokenizer.name_or_path

    # Create the filter using its factory method
    canonicality_filter = FastCanonicalityFilterBPE.from_tokenizer(
        tokenizer, eos_token_ids
    )

    # Set overrides on the filter, based on the tokenizer's `name_or_path`
    canonicality_filter.set_overrides(model_name)

    # Only the filter is passed on; the tokenizer is not stored on the potential.
    return cls(canonicality_filter)

complete(context) async

Assess if a complete sequence follows canonical tokenization.

Parameters:

Name Type Description Default
context list

Sequence of tokens

required

Returns:

Type Description
float

0.0 if canonical, float('-inf') otherwise

Source code in genlm/control/potential/built_in/canonical.py
async def complete(self, context):
    """
    Return the log weight of a finished token sequence under canonical tokenization.

    Args:
        context (list): Sequence of tokens

    Returns:
        (float): 0.0 if canonical, float('-inf') otherwise
    """
    # An empty sequence has no adjacent tokens that could violate canonicality.
    if not context:
        return 0.0
    if self._check_canonicality(context):
        return 0.0
    return float("-inf")

prefix(context) async

Assess if a prefix sequence could potentially extend to a canonical sequence. For canonicality, this is the same as complete.

Parameters:

Name Type Description Default
context list

Sequence of tokens

required

Returns:

Type Description
float

0.0 if potentially canonical, float('-inf') otherwise

Source code in genlm/control/potential/built_in/canonical.py
async def prefix(self, context):
    """
    Return the log weight of `context` viewed as a prefix.

    For canonicality the prefix weight equals the complete weight, so this
    simply awaits `complete`.

    Args:
        context (list): Sequence of tokens

    Returns:
        (float): 0.0 if potentially canonical, float('-inf') otherwise
    """
    result = await self.complete(context)
    return result

logw_next(context) async

Compute weights for each possible next token given the context.

Parameters:

Name Type Description Default
context list

Sequence of tokens

required

Returns:

Type Description
LazyWeights

Weights for each token in the vocabulary and EOS

Source code in genlm/control/potential/built_in/canonical.py
async def logw_next(self, context):
    """
    Compute weights for each possible next token given the context.

    Allowed tokens get log weight 0.0 and disallowed tokens get -inf. EOS is
    always allowed.

    Args:
        context (list): Sequence of tokens

    Returns:
        (LazyWeights): Weights for each token in the vocabulary and EOS

    Raises:
        ValueError: If `context` itself is not canonical.
    """
    # Guard: a non-canonical context cannot be extended.
    if await self.prefix(context) == float("-inf"):
        raise ValueError("Context is non-canonical")

    if context:
        # The filter is queried with (None, previous_token).
        allowed = self.canonicality_filter((None, context[-1]))
    else:
        vocab_size = len(self.canonicality_filter._decode)
        allowed = np.ones(vocab_size, dtype=bool)

    # Map the mask straight to log weights (no np.log, no torch) so the result
    # can be combined with other potentials.
    token_logws = np.where(allowed, 0.0, float("-inf")).astype(np.float32)
    # `_decode` lacks EOS while `vocab_eos` includes it; append EOS as always allowed.
    eos_logw = np.array([0.0], dtype=np.float32)
    return self.make_lazy_weights(np.concatenate([token_logws, eos_logw]))

SMC

This class implements sequential Monte Carlo (SMC) inference for controlled text generation. The generation process works as follows:

  1. Token Sampling: At each step, the unit_sampler is used to extend each particle (candidate sequence) by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs an importance weight with each extension to correct for the myopic nature of token-by-token sampling.

  2. Critic Evaluation: If a critic is provided, it scores the updated sequences (via its score method), reweighting the particles based on how well they satisfy the constraints encoded by the critic.

  3. Resampling: When the effective sample size (ESS) falls below the threshold, particles are resampled according to their weights. This helps focus computation on more promising sequences.

  4. Termination: The process continues until either:

    • All sequences reach an end-of-sequence (EOS) token

    • The maximum token length is reached

If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's target potential and the critic potential (unit_sampler.target * critic). If a critic is not provided, the resulting sequences are weighted with respect to the unit sampler's target potential.

Parameters:

Name Type Description Default
unit_sampler TokenSampler

The sampler that generates tokens.

required
critic Potential

A potential function that guides the generation process by scoring candidate sequences. Must have the same token type as the unit_sampler.

None

Raises:

Type Description
ValueError

If unit_sampler is not a TokenSampler, if critic is not a Potential, or if the token types of unit_sampler and critic don't match.

Source code in genlm/control/sampler/sequence.py
class SMC:
    """This class implements sequential Monte Carlo (SMC) inference for controlled text generation.
    The generation process works as follows:

    1. Token Sampling: At each step, the `unit_sampler` is used to extend each particle (candidate sequence)
       by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs
       an importance weight with each extension to correct for the myopic nature of token-by-token sampling.

    2. Critic Evaluation: If a `critic` is provided, it scores the updated sequences (via its `score` method),
       reweighting the particles based on how well they satisfy the constraints encoded by the critic.

    3. Resampling: When the effective sample size (ESS) falls below the threshold,
       particles are resampled according to their weights. This helps focus computation
       on more promising sequences.

    4. Termination: The process continues until either:\n
        - All sequences reach an end-of-sequence (EOS) token\n
        - The maximum token length is reached

    If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's
    target potential and the critic potential (`unit_sampler.target * critic`). If a critic is not provided,
    the resulting sequences are weighted with respect to the unit sampler's target potential.

    Args:
        unit_sampler (TokenSampler): The sampler that generates tokens.
        critic (Potential, optional): A potential function that guides the generation process
            by scoring candidate sequences. Must have the same token type as the unit_sampler.

    Raises:
        ValueError: If unit_sampler is not a TokenSampler, if critic is not a Potential,
            or if the token types of unit_sampler and critic don't match.
    """

    def __init__(self, unit_sampler, critic=None):
        if not isinstance(unit_sampler, TokenSampler):
            raise ValueError("`unit_sampler` must be a TokenSampler")

        # Compare against None explicitly instead of relying on truthiness,
        # which an object can customize via __bool__/__len__.
        if critic is not None:
            if not isinstance(critic, Potential):
                raise ValueError("`critic` must be a Potential")
            if not unit_sampler.token_type == critic.token_type:
                raise ValueError(
                    "`critic` must have the same token type as the `unit_sampler`. "
                    f"Got {unit_sampler.token_type} and {critic.token_type}."
                    + (
                        "\nMaybe you forgot to coerce the critic to the token type of the unit sampler? See `Coerce`."
                        if unit_sampler.token_type.is_iterable_of(critic.token_type)
                        else ""
                    )
                )

        self.unit_sampler = unit_sampler
        self.critic = critic

    async def __call__(
        self,
        n_particles,
        ess_threshold,
        max_tokens,
        verbosity=0,
        json_path=None,
        **kwargs,
    ):
        """Generate sequences using sequential Monte Carlo inference.

        Args:
            n_particles (int): Number of particles (candidate sequences) to maintain during
                generation. Higher values provide better exploration but require more
                computation.
            ess_threshold (float): Effective sample size threshold for resampling,
                expressed as a fraction of the number of particles. When ESS falls below
                this value, particles are resampled according to their weights. Should be between 0 and 1.
                Higher values lead to more frequent resampling. Note that when ess_threshold = 0,
                the critic is only applied at the end of the generation (if it is provided).
            max_tokens (int): Maximum number of tokens to generate per sequence. Generation
                may terminate earlier if all sequences reach an EOS token.
            verbosity (int, optional): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the
                particles at each step. Default is 0.
            json_path (str, optional): JSON file path for saving a record of the inference run.
                This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
            **kwargs (dict): Additional keyword arguments to pass to the SMC algorithm.
                See the `llamppl.inference.smc_standard` documentation for more details.

        Returns:
            (Sequences): A container holding the generated sequences, their importance weights, and
                other metadata from the generation process.
        """
        model = SequenceModel(
            unit_sampler=self.unit_sampler,
            critic=self.critic,
            max_tokens=max_tokens,
            verbosity=verbosity,
            # With ess_threshold == 0 the critic is only applied at the end of
            # generation, so intermediate twisting is enabled only when positive.
            twist_with_critic=ess_threshold > 0,
        )

        particles = await smc_standard(
            model=model,
            n_particles=n_particles,
            ess_threshold=ess_threshold,
            json_file=json_path,
            **kwargs,
        )

        return Sequences(*_unpack_particles(particles))

    async def cleanup(self):
        """Clean up resources used by the inference engine.

        This method should be called when the SMC sampler is no longer needed.

        Example:
            ```python
            sampler = SMC(unit_sampler, critic)
            try:
                sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
            finally:
                await sampler.cleanup()
            ```
        """
        await self.unit_sampler.cleanup()
        if self.critic is not None:
            await self.critic.cleanup()

__call__(n_particles, ess_threshold, max_tokens, verbosity=0, json_path=None, **kwargs) async

Generate sequences using sequential Monte Carlo inference.

Parameters:

Name Type Description Default
n_particles int

Number of particles (candidate sequences) to maintain during generation. Higher values provide better exploration but require more computation.

required
ess_threshold float

Effective sample size threshold for resampling, expressed as a fraction of the number of particles. When ESS falls below this value, particles are resampled according to their weights. Should be between 0 and 1. Higher values lead to more frequent resampling. Note that when ess_threshold = 0, the critic is only applied at the end of the generation (if it is provided).

required
max_tokens int

Maximum number of tokens to generate per sequence. Generation may terminate earlier if all sequences reach an EOS token.

required
verbosity int

Verbosity level for the SMC algorithm. 0 is silent, 1 prints the particles at each step. Default is 0.

0
json_path str

JSON file path for saving a record of the inference run. This can be used in conjunction with the InferenceVisualizer to visualize the inference run.

None
**kwargs dict

Additional keyword arguments to pass to the SMC algorithm. See the llamppl.inference.smc_standard documentation for more details.

{}

Returns:

Type Description
Sequences

A container holding the generated sequences, their importance weights, and other metadata from the generation process.

Source code in genlm/control/sampler/sequence.py
async def __call__(
    self,
    n_particles,
    ess_threshold,
    max_tokens,
    verbosity=0,
    json_path=None,
    **kwargs,
):
    """Run sequential Monte Carlo inference and return the resulting sequences.

    Args:
        n_particles (int): Number of particles (candidate sequences) to maintain during
            generation. Higher values provide better exploration but require more
            computation.
        ess_threshold (float): Effective sample size threshold for resampling,
            expressed as a fraction of the number of particles. When ESS falls below
            this value, particles are resampled according to their weights. Should be between 0 and 1.
            Higher values lead to more frequent resampling. Note that when ess_threshold = 0,
            the critic is only applied at the end of the generation (if it is provided).
        max_tokens (int): Maximum number of tokens to generate per sequence. Generation
            may terminate earlier if all sequences reach an EOS token.
        verbosity (int, optional): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the
            particles at each step. Default is 0.
        json_path (str, optional): JSON file path for saving a record of the inference run.
            This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
        **kwargs (dict): Additional keyword arguments to pass to the SMC algorithm.
            See the `llamppl.inference.smc_standard` documentation for more details.

    Returns:
        (Sequences): A container holding the generated sequences, their importance weights, and
            other metadata from the generation process.
    """
    # With ess_threshold == 0 the critic is only applied at the end of generation,
    # so intermediate twisting with the critic is enabled only for positive thresholds.
    twist = ess_threshold > 0
    model = SequenceModel(
        unit_sampler=self.unit_sampler,
        critic=self.critic,
        max_tokens=max_tokens,
        verbosity=verbosity,
        twist_with_critic=twist,
    )

    particles = await smc_standard(
        model=model,
        n_particles=n_particles,
        ess_threshold=ess_threshold,
        json_file=json_path,
        **kwargs,
    )

    unpacked = _unpack_particles(particles)
    return Sequences(*unpacked)

cleanup() async

Clean up resources used by the inference engine.

This method should be called when the SMC sampler is no longer needed.

Example
sampler = SMC(unit_sampler, critic)
try:
    sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
finally:
    await sampler.cleanup()
Source code in genlm/control/sampler/sequence.py
async def cleanup(self):
    """Clean up resources used by the inference engine.

    Awaits `cleanup` on the unit sampler, then on the critic if one was provided.
    This method should be called when the SMC sampler is no longer needed.

    Example:
        ```python
        sampler = SMC(unit_sampler, critic)
        try:
            sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
        finally:
            await sampler.cleanup()
        ```
    """
    await self.unit_sampler.cleanup()
    # Explicit None check: a critic's truthiness could be customized via __bool__/__len__.
    if self.critic is not None:
        await self.critic.cleanup()

direct_token_sampler(potential)

Create a DirectTokenSampler that samples directly from a potential's vocabulary.

See DirectTokenSampler for more details.

Parameters:

Name Type Description Default
potential Potential

The potential function to sample from. Should have an efficient logw_next method.

required

Returns:

Type Description
DirectTokenSampler

A sampler that directly samples tokens from the potential's vocabulary.

Source code in genlm/control/sampler/__init__.py
def direct_token_sampler(potential):
    """Build a `DirectTokenSampler` that draws tokens straight from `potential`'s vocabulary.

    See `DirectTokenSampler` for more details.

    Args:
        potential (Potential): The potential function to sample from. Should have an efficient logw_next method.

    Returns:
        (DirectTokenSampler): A sampler that directly samples tokens from the potential's vocabulary.
    """
    # NOTE(review): this is an `assert`, so the type check is skipped under `python -O`.
    assert isinstance(potential, Potential)
    sampler = DirectTokenSampler(potential)
    return sampler

eager_token_sampler(iter_potential, item_potential)

Create a SetTokenSampler that uses the EagerSetSampler to sample a set of tokens.

See EagerSetSampler for more details.

Parameters:

Name Type Description Default
iter_potential Potential

A potential function defined over a vocabulary of iterables.

required
item_potential Potential

A potential function defined over a vocabulary of items which are elements of the iterables.

required

Returns:

Type Description
SetTokenSampler

A sampler that wraps an EagerSetSampler.

Note

This is the fastest sampler in most cases.

Source code in genlm/control/sampler/__init__.py
def eager_token_sampler(iter_potential, item_potential):
    """Build a `SetTokenSampler` backed by an `EagerSetSampler`.

    See `EagerSetSampler` for more details.

    Args:
        iter_potential (Potential): A potential function defined over a vocabulary of iterables.
        item_potential (Potential): A potential function defined over a vocabulary of items which are elements of the iterables.

    Returns:
        (SetTokenSampler): A sampler that wraps an `EagerSetSampler`.

    Note:
        This is the fastest sampler in most cases.
    """
    set_sampler = EagerSetSampler(iter_potential, item_potential)
    return SetTokenSampler(set_sampler)

topk_token_sampler(iter_potential, item_potential, K)

Create a SetTokenSampler that uses the TopKSetSampler to sample a set of tokens.

See TopKSetSampler for more details.

Parameters:

Name Type Description Default
iter_potential Potential

A potential function defined over a vocabulary of iterables.

required
item_potential Potential

A potential function defined over a vocabulary of items which are elements of the iterables.

required
K int | None

The K parameter for the TopKSetSampler.

required

Returns:

Type Description
SetTokenSampler

A sampler that wraps a TopKSetSampler.

Source code in genlm/control/sampler/__init__.py
def topk_token_sampler(iter_potential, item_potential, K):
    """Build a `SetTokenSampler` backed by a `TopKSetSampler`.

    Refer to `TopKSetSampler` for the details of how the token set is sampled.

    Args:
        iter_potential (Potential): Potential defined over a vocabulary of iterables.
        item_potential (Potential): Potential defined over a vocabulary of items,
            i.e. the elements making up those iterables.
        K (int|None): Forwarded unchanged as the `K` argument of `TopKSetSampler`.

    Returns:
        (SetTokenSampler): A token sampler wrapping a `TopKSetSampler`.
    """
    set_sampler = TopKSetSampler(iter_potential, item_potential, K)
    return SetTokenSampler(set_sampler)

AWRS

Bases: TokenSampler

Samples individual tokens through an adaptive weighted rejection sampling algorithm.

This sampler is based on the algorithm described in Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling (https://arxiv.org/abs/2504.05410).

It draws properly weighted samples from the product of a non-boolean potential and a boolean condition.

Parameters:

Name Type Description Default
potential Potential

The non-boolean potential.

required
condition Potential

The boolean condition. This potential must only output boolean values (0 or -inf in log-space).

required
seed int or None

The seed for the random number generator.

None
prune_logws bool

Whether to prune the logws to only include the tokens in the intersection of the potential and condition vocabularies.

True
proper_weights bool

Whether to return properly weighted samples. If False, the sampler will only run one round of adaptive rejection sampling.

True
max_accepts int

The maximum number of tokens to accept - higher values will decrease the variance of the weight estimate.

2
max_rejects int or float('inf')

The maximum number of tokens to reject - lower values will run faster, but at the cost of returning a weight of zero for some samples where there are tokens that would be accepted if tested.

float('inf')
n_monte_carlo_samples int

Deprecated and ignored. Passing any value other than None emits a DeprecationWarning ("n_monte_carlo_samples no longer does anything.").

None
Source code in genlm/control/sampler/token.py
class AWRS(TokenSampler):
    """Samples individual tokens through an adaptive weighted rejection sampling algorithm.

    This sampler is based on the algorithm described in [Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling](https://arxiv.org/abs/2504.05410)

    It draws properly weighted samples from the product of a non-boolean potential and a boolean condition.

    Args:
        potential (Potential): The non-boolean potential.
        condition (Potential): The boolean condition. This potential must only output boolean values (0 or -inf in log-space).
        seed (int or None): The seed for the random number generator.
        prune_logws (bool): Whether to prune the logws to only include the tokens in the intersection of the potential and condition vocabularies.
        proper_weights (bool): Whether to return properly weighted samples.
            If False, the sampler will only run one round of adaptive rejection sampling.
        max_accepts (int): The maximum number of tokens to accept - higher values will decrease the variance of the weight estimate.
            Must be at least 2 when `proper_weights` is True.
        max_rejects (int, float('inf') or None): The maximum number of tokens to reject - lower values will run faster, but at the cost of returning a weight of zero for some samples where there are tokens that would be accepted if tested.
            `None` means no limit. Must be at least 2 when `proper_weights` is True; when `proper_weights` is False, 0 is also treated as no limit.
        n_monte_carlo_samples (int or None): Deprecated and ignored. Passing any value other than None emits a `DeprecationWarning`.

    Raises:
        ValueError: If `proper_weights` is True and `max_accepts` or `max_rejects` is below 2.
    """

    def __init__(
        self,
        potential,
        condition,
        seed=None,
        prune_logws=True,
        proper_weights=True,
        max_accepts=2,
        max_rejects=float("inf"),
        n_monte_carlo_samples=None,
    ):
        super().__init__(target=potential * condition)
        self.potential = potential
        self.condition = condition

        self.prune_logws = prune_logws
        self.proper_weights = proper_weights

        # Normalise "no limit" before validating. Comparing None with `<`
        # would otherwise raise TypeError before the fallback below could apply.
        if max_rejects is None:
            max_rejects = float("inf")

        if max_accepts < 2 and proper_weights:
            raise ValueError("`max_accepts` must be at least 2")

        if max_rejects < 2 and proper_weights:
            raise ValueError("`max_rejects` must be at least 2")

        if n_monte_carlo_samples is not None:
            warnings.warn(
                "n_monte_carlo_samples no longer does anything.",
                DeprecationWarning,
            )

        self.max_accepts = max_accepts
        # A falsy limit (0, only reachable when proper_weights is False) keeps
        # its historical meaning of "no limit".
        self.max_rejects = max_rejects or float("inf")

        # Indices, in the potential's vocab_eos ordering, of every token that
        # is also in the target (potential * condition) vocabulary.
        self.valid_idxs = np.array(
            [self.potential.lookup[t] for t in self.target.vocab_eos]
        )

        self.vocab_eos_set = set(self.target.vocab_eos)
        self.V = len(self.potential.vocab_eos)
        self.rng = np.random.default_rng(seed=seed)

    def _prune_logws(self, logws):
        # Prune the logws to only include the tokens in the
        # target vocabulary. (This zeros-out tokens which we know a priori
        # will be rejected.) Note: We need an additional correction term
        # to account for the fact that we're throwing away some probability mass.
        # This should be handled in `sample`.
        pruned = self.potential.alloc_logws()
        pruned[self.valid_idxs] = logws.weights[self.valid_idxs]
        logws.weights = pruned
        return logws

    async def _accept(self, context, token, verbosity=0):
        """Return True iff `condition` assigns log weight 0 to `context + [token]`.

        EOS is scored with `condition.complete(context)`; any other token with
        `condition.prefix(context + [token])`. When logws are not pruned, tokens
        outside the target vocabulary are rejected without querying `condition`.
        """
        if self.prune_logws or token in self.vocab_eos_set:
            if token is self.target.eos:
                logscore = await self.condition.complete(context)
            else:
                logscore = await self.condition.prefix(context + [token])
            if logscore not in {-np.inf, 0}:
                # Explicit raise rather than `assert` so the check survives
                # `python -O`; the exception type and message are unchanged.
                raise AssertionError("`condition` must be Boolean")
        else:
            logscore = -np.inf

        do_accept = logscore == 0

        if verbosity > 0:
            if do_accept:
                print(colors.green % f". {repr(token)}")
            else:
                print(colors.red % ".", end="")

        return do_accept

    async def sample(self, context, verbosity=0):
        """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method via adaptive weighted rejection sampling.

        The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

        Returns:
            (token, weight, np.nan): A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.
        """
        logws = await self.potential.logw_next(context)
        if self.prune_logws:
            logws = self._prune_logws(logws)

        logZ = logsumexp(logws.weights)
        logps = logws.weights - logZ
        toks = logws.decode

        # We cache successful calls, as algorithms may want to see the
        # same successful token more than once (currently just geometric_awrs)
        cache = {}

        async def accept(tok):
            try:
                return cache[tok]
            except KeyError:
                pass
            result = await self._accept(context, tok, verbosity)
            if result:
                cache[tok] = result
            return result

        if not self.proper_weights:
            return await improper_sample(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
            )
        # We pick which algorithm to use based on parameters and the
        # shape of the distribution, as this lets us pick the most
        # effective option.
        elif (
            # If max_accepts is large then recursive_awrs (which
            # does not currently support this parameter) isn't very
            # useful, because the recursive step means that you never
            # revisit the same value, so will often throw away most
            # of the accepted mass if you were to continue. Also
            # this parameter is only really relevant if you want to
            # lower the variance, and geometric_awrs is lower variance.
            self.max_accepts > 2
            or
            # If the distribution is strongly peaked around a single value
            # then geometric_awrs will be more efficient. See below
            # for specific derivation.
            logps.max() >= GEOMETRIC_THRESHOLD
        ):
            tok, w, _ = await geometric_awrs(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
                max_accepts=self.max_accepts,
            )
            return tok, w + logZ, np.nan
        else:
            tok, w, _ = await recursive_awrs(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
            )
            return tok, w + logZ, np.nan

sample(context, verbosity=0) async

Sample a token and weight that are properly weighted with respect to the target potential's logw_next method via adaptive weighted rejection sampling.

The returned weight corresponds to the log normalizing constant of \(\textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1})\).

Returns:

Type Description
(token, weight, nan)

A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.

Source code in genlm/control/sampler/token.py
async def sample(self, context, verbosity=0):
    """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method via adaptive weighted rejection sampling.

    The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

    Returns:
        (token, weight, np.nan): A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.
    """
    next_logws = await self.potential.logw_next(context)
    if self.prune_logws:
        next_logws = self._prune_logws(next_logws)

    # Normalise the (possibly pruned) weights into log-probabilities; the
    # normaliser is added back onto the weight returned by proper samplers.
    log_normalizer = logsumexp(next_logws.weights)
    log_probs = next_logws.weights - log_normalizer
    decode = next_logws.decode

    # Only accepted tokens are memoised: an algorithm may revisit an
    # accepted token (currently just geometric_awrs), so re-querying the
    # condition for it would be wasted work.
    accepted = {}

    async def accept(tok):
        if tok in accepted:
            return accepted[tok]
        verdict = await self._accept(context, tok, verbosity)
        if verdict:
            accepted[tok] = verdict
        return verdict

    if not self.proper_weights:
        return await improper_sample(
            logps=log_probs,
            toks=decode,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
        )

    # Algorithm choice: recursive_awrs does not support max_accepts and never
    # revisits a value, so geometric_awrs (lower variance) is used whenever
    # more than two accepts are requested, or when the distribution is
    # strongly peaked (max log-probability at or above GEOMETRIC_THRESHOLD).
    use_geometric = self.max_accepts > 2 or log_probs.max() >= GEOMETRIC_THRESHOLD

    if use_geometric:
        tok, log_weight, _ = await geometric_awrs(
            logps=log_probs,
            toks=decode,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
            max_accepts=self.max_accepts,
        )
    else:
        tok, log_weight, _ = await recursive_awrs(
            logps=log_probs,
            toks=decode,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
        )
    return tok, log_weight + log_normalizer, np.nan

InferenceVisualizer

Web-based visualization server for SMC inference results.

This class is intended to be used in conjunction with the InferenceEngine class.

Example
from genlm.control import InferenceVisualizer
# create the visualizer
viz = InferenceVisualizer()
# run inference and save the record to a JSON file
sequences = await token_sampler.smc(
    n_particles=10,
    max_tokens=20,
    ess_threshold=0.5,
    json_path="smc_record.json" # save the record to a JSON file
)
# visualize the inference run
viz.visualize("smc_record.json")
# clean up visualization server
viz.shutdown_server()
Source code in genlm/control/viz.py
class InferenceVisualizer:
    """Web-based visualization server for SMC inference results.

    This class is intended to be used in conjunction with the `InferenceEngine` class.

    Example:
        ```python
        from genlm.control import InferenceVisualizer
        # create the visualizer
        viz = InferenceVisualizer()
        # run inference and save the record to a JSON file
        sequences = await token_sampler.smc(
            n_particles=10,
            max_tokens=20,
            ess_threshold=0.5,
            json_path="smc_record.json" # save the record to a JSON file
        )
        # visualize the inference run
        viz.visualize("smc_record.json")
        # clean up visualization server
        viz.shutdown_server()
        ```
    """

    def __init__(self, port=8000, serve_dir=None):
        """Initialize the visualization server.

        The server listens on all interfaces (host ``""``) at `port`. Request
        paths ending in ``.html`` are served from the package's ``html``
        directory and all other paths from `serve_dir`. Request paths are
        percent-decoded and normalised so they cannot escape those directories.

        Args:
            port (int): Port to run the server on.
            serve_dir (str | Path, optional): Directory to serve files from.
                If None, creates a temporary directory that `shutdown_server`
                removes.

        Raises:
            OSError: If the port is already in use
        """
        import os
        import posixpath
        from urllib.parse import unquote

        self._server = None
        self._server_thread = None
        self._port = port
        self._html_dir = Path(__file__).parent / "html"

        # Set up serve directory
        if serve_dir is None:
            self._serve_dir = Path(tempfile.mkdtemp(prefix="smc_viz_"))
            self._using_temp_dir = True
        else:
            self._serve_dir = Path(serve_dir).resolve()
            self._using_temp_dir = False
            self._serve_dir.mkdir(exist_ok=True)

        # Create handler that serves from both directories
        class Handler(http.server.SimpleHTTPRequestHandler):
            def translate_path(self_, path):
                # Drop query string and fragment, then percent-decode, as the
                # stdlib SimpleHTTPRequestHandler.translate_path does.
                clean_path = path.split("?", 1)[0].split("#", 1)[0]
                clean_path = unquote(clean_path, errors="surrogatepass")
                # HTML files come from package; everything else from serve dir.
                base = (
                    self._html_dir if clean_path.endswith(".html") else self._serve_dir
                )
                # Rebuild the path one segment at a time, skipping empty,
                # "." and ".." segments and any segment containing a platform
                # separator, so the result always stays inside `base`.
                resolved = base
                for part in posixpath.normpath(clean_path).split("/"):
                    if (
                        not part
                        or part in (os.curdir, os.pardir)
                        or os.path.dirname(part)
                    ):
                        continue
                    resolved = resolved / part
                return str(resolved)

        try:
            self._start_server(Handler)
        except Exception:
            # Remove a temporary serve directory now instead of waiting for
            # the garbage collector to run __del__.
            self.shutdown_server()
            raise

    def visualize(self, json_path, auto_open=False):
        """Visualize the inference run in a browser.

        Args:
            json_path (str | Path): Path to the JSON file to visualize. If the file is not
                in the serve directory, it will be copied there. For efficiency, you can
                write JSON files directly to the serve directory.
            auto_open (bool): Whether to open the URL in a web browser.

        Returns:
            (str): URL where visualization can be accessed

        Raises:
            RuntimeError: If the server is not running.
            FileNotFoundError: If `json_path` does not exist.
        """
        from urllib.parse import quote

        if self._server is None:
            raise RuntimeError("Server is not running")

        json_path = Path(json_path)
        if not json_path.exists():
            raise FileNotFoundError(f"JSON file not found: {json_path}")

        # If file isn't in serve directory, copy it there
        dest_path = self._serve_dir / json_path.name
        if json_path.resolve() != dest_path.resolve():
            shutil.copy2(json_path, dest_path)

        # Percent-encode the file name so characters such as spaces, '&' or
        # '#' survive as a single query-string value.
        encoded_name = quote(json_path.name, safe="")
        url = f"http://localhost:{self._port}/smc.html?path={encoded_name}"

        if auto_open:
            webbrowser.open(url)

        return url

    def _start_server(self, handler_class):
        """Bind the HTTP server to `self._port` and serve it from a daemon thread.

        Raises:
            OSError: If the port is already in use, or the socket cannot be
                created or bound.
        """
        import errno

        server = None
        try:
            server = socketserver.TCPServer(
                ("", self._port), handler_class, bind_and_activate=False
            )
            # Must be set before server_bind(), which reads it.
            server.allow_reuse_address = True
            server.server_bind()
            server.server_activate()
        except OSError as e:
            # Release the socket immediately rather than leaking it.
            if server is not None:
                server.server_close()
            # Address already in use (98 on Linux, 48 on macOS).
            if e.errno in (errno.EADDRINUSE, 48, 98):
                raise OSError(f"Port {self._port} is already in use") from None
            raise

        self._server = server
        self._server_thread = threading.Thread(target=server.serve_forever)
        self._server_thread.daemon = True
        self._server_thread.start()

    def shutdown_server(self):
        """Shut down the visualization server.

        Stops and joins the serving thread if it is alive, closes the server
        socket, and deletes the serve directory if it was a temporary one.
        Safe to call more than once.
        """
        if self._server is not None:
            if self._server_thread is not None and self._server_thread.is_alive():
                self._server.shutdown()
                self._server_thread.join()
            self._server.server_close()
            self._server = None
            self._server_thread = None

        # Clean up any temporary files
        if self._using_temp_dir and self._serve_dir.exists():
            shutil.rmtree(self._serve_dir)

    def __del__(self):
        """Ensure server is shut down when object is deleted."""
        # __init__ may have raised before `_using_temp_dir` was assigned (e.g.
        # tempfile.mkdtemp failed); no server or temp dir exists then, so
        # return instead of raising AttributeError from the finalizer.
        if not hasattr(self, "_using_temp_dir"):
            return
        self.shutdown_server()

__init__(port=8000, serve_dir=None)

Initialize the visualization server.

Parameters:

Name Type Description Default
port int

Port to run the server on.

8000
serve_dir str | Path

Directory to serve files from. If None, creates a temporary directory.

None

Raises:

Type Description
OSError

If the port is already in use

Source code in genlm/control/viz.py
def __init__(self, port=8000, serve_dir=None):
    """Initialize the visualization server.

    The server listens on all interfaces (host ``""``) at `port`. Request
    paths ending in ``.html`` are served from the package's ``html``
    directory and all other paths from `serve_dir`. Request paths are
    percent-decoded and normalised so they cannot escape those directories.

    Args:
        port (int): Port to run the server on.
        serve_dir (str | Path, optional): Directory to serve files from.
            If None, creates a temporary directory that `shutdown_server`
            removes.

    Raises:
        OSError: If the port is already in use
    """
    import os
    import posixpath
    from urllib.parse import unquote

    self._server = None
    self._server_thread = None
    self._port = port
    self._html_dir = Path(__file__).parent / "html"

    # Set up serve directory
    if serve_dir is None:
        self._serve_dir = Path(tempfile.mkdtemp(prefix="smc_viz_"))
        self._using_temp_dir = True
    else:
        self._serve_dir = Path(serve_dir).resolve()
        self._using_temp_dir = False
        self._serve_dir.mkdir(exist_ok=True)

    # Create handler that serves from both directories
    class Handler(http.server.SimpleHTTPRequestHandler):
        def translate_path(self_, path):
            # Drop query string and fragment, then percent-decode, as the
            # stdlib SimpleHTTPRequestHandler.translate_path does.
            clean_path = path.split("?", 1)[0].split("#", 1)[0]
            clean_path = unquote(clean_path, errors="surrogatepass")
            # HTML files come from package; everything else from serve dir.
            base = self._html_dir if clean_path.endswith(".html") else self._serve_dir
            # Rebuild the path one segment at a time, skipping empty, "." and
            # ".." segments and any segment containing a platform separator,
            # so the result always stays inside `base`.
            resolved = base
            for part in posixpath.normpath(clean_path).split("/"):
                if not part or part in (os.curdir, os.pardir) or os.path.dirname(part):
                    continue
                resolved = resolved / part
            return str(resolved)

    try:
        self._start_server(Handler)
    except Exception:
        # Remove a temporary serve directory now instead of waiting for the
        # garbage collector to run __del__.
        self.shutdown_server()
        raise

visualize(json_path, auto_open=False)

Visualize the inference run in a browser.

Parameters:

Name Type Description Default
json_path str | Path

Path to the JSON file to visualize. If the file is not in the serve directory, it will be copied there. For efficiency, you can write JSON files directly to the serve directory.

required
auto_open bool

Whether to automatically open the visualization URL in a web browser.

False

Returns:

Type Description
str

URL where visualization can be accessed

Source code in genlm/control/viz.py
def visualize(self, json_path, auto_open=False):
    """Visualize the inference run in a browser.

    Args:
        json_path (str | Path): Path to the JSON file to visualize. If the file is not
            in the serve directory, it will be copied there. For efficiency, you can
            write JSON files directly to the serve directory.
        auto_open (bool): Whether to open the URL in a web browser.

    Returns:
        (str): URL where visualization can be accessed

    Raises:
        RuntimeError: If the server is not running.
        FileNotFoundError: If `json_path` does not exist.
    """
    from urllib.parse import quote

    if self._server is None:
        raise RuntimeError("Server is not running")

    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"JSON file not found: {json_path}")

    # If file isn't in serve directory, copy it there
    dest_path = self._serve_dir / json_path.name
    if json_path.resolve() != dest_path.resolve():
        shutil.copy2(json_path, dest_path)

    # Percent-encode the file name so characters such as spaces, '&' or '#'
    # survive as a single query-string value.
    encoded_name = quote(json_path.name, safe="")
    url = f"http://localhost:{self._port}/smc.html?path={encoded_name}"

    if auto_open:
        webbrowser.open(url)

    return url

shutdown_server()

Shut down the visualization server.

Source code in genlm/control/viz.py
def shutdown_server(self):
    """Stop the visualization server and remove any temporary serve directory.

    When a server exists: if its serving thread is still alive, the server is
    asked to shut down and the thread is joined; the socket is then closed and
    both references are cleared. Afterwards, a serve directory that was
    created as a temporary directory is deleted if it still exists. Calling
    this again is harmless.
    """
    server = self._server
    if server is not None:
        thread = self._server_thread
        thread_running = thread is not None and thread.is_alive()
        if thread_running:
            server.shutdown()
            thread.join()
        server.server_close()
        self._server = self._server_thread = None

    # Only a directory this object created itself is deleted.
    owns_serve_dir = self._using_temp_dir
    if owns_serve_dir and self._serve_dir.exists():
        shutil.rmtree(self._serve_dir)

__del__()

Ensure server is shut down when object is deleted.

Source code in genlm/control/viz.py
def __del__(self):
    """Ensure server is shut down when object is deleted.

    `__init__` may raise before `_using_temp_dir` is assigned (for example if
    the temporary directory cannot be created). In that case no server or
    temporary directory exists yet, so return quietly instead of letting
    `shutdown_server` raise `AttributeError` from the finalizer.
    """
    if not hasattr(self, "_using_temp_dir"):
        return
    self.shutdown_server()