sampler

DirectTokenSampler

Bases: TokenSampler

Samples individual tokens directly from the log-normalized logw_next function of a potential.

Parameters:

- `potential` (`Potential`, required): The potential function to sample from.

Warning

Only use this sampler if the potential's logw_next method is efficient. This is the case for potentials like PromptedLLM, but for custom potentials with a large vocabulary size, the default implementation of logw_next generally will not be efficient, and thus this sampler will be slow.

Source code in genlm/control/sampler/token.py
class DirectTokenSampler(TokenSampler):
    """Samples individual tokens directly from the log-normalized `logw_next` function
    of a potential.

    Args:
        potential (Potential): The potential function to sample from

    Warning:
        Only use this sampler if the potential's `logw_next` method is efficient. This is the case
        for potentials like `PromptedLLM`, but for custom potentials with a large vocabulary size,
        the default implementation of `logw_next` generally will not be efficient, and thus this
        sampler will be slow.
    """

    def __init__(self, potential):
        super().__init__(target=potential)
        self.potential = potential

    async def sample(self, context, draw=None):
        """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method.

        Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the target potential's vocabulary,
        this method samples a token $x_n \\in \\textsf{target.vocab_eos}$ and weight $w$.

        The sampled token and weight are properly weighted with respect to
        $$
        \\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
        $$

        The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

        Returns:
            (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the sampled token.
        """
        logws = await self.potential.logw_next(context)
        logps = logws.normalize()
        if draw is None:
            # fast sampling from logps using gumbel-max trick
            token = fast_sample_lazyweights(logps)
        else:
            token = draw(logps.exp().materialize())
        return token, logws.sum(), logps[token]

    async def cleanup(self):
        pass  # pragma: no cover

sample(context, draw=None) async

Sample a token and weight that are properly weighted with respect to the target potential's logw_next method.

Given a context of tokens \(x_1, \ldots, x_{n-1}\) in the target potential's vocabulary, this method samples a token \(x_n \in \textsf{target.vocab_eos}\) and weight \(w\).

The sampled token and weight are properly weighted with respect to $$ \textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1}) $$

The returned weight corresponds to the log normalizing constant of \(\textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1})\).

Returns:

- `(token, weight, logp)`: A tuple containing the sampled token, weight, and log-probability of the sampled token.

Source code in genlm/control/sampler/token.py
async def sample(self, context, draw=None):
    """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method.

    Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the target potential's vocabulary,
    this method samples a token $x_n \\in \\textsf{target.vocab_eos}$ and weight $w$.

    The sampled token and weight are properly weighted with respect to
    $$
    \\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
    $$

    The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

    Returns:
        (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the sampled token.
    """
    logws = await self.potential.logw_next(context)
    logps = logws.normalize()
    if draw is None:
        # fast sampling from logps using gumbel-max trick
        token = fast_sample_lazyweights(logps)
    else:
        token = draw(logps.exp().materialize())
    return token, logws.sum(), logps[token]
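The "gumbel-max trick" mentioned in the comment above can be illustrated in isolation. This is a minimal sketch of the trick itself, not the library's `fast_sample_lazyweights` implementation; the helper name `gumbel_max_sample` is hypothetical:

```python
import numpy as np

def gumbel_max_sample(logps, rng):
    # Add i.i.d. Gumbel(0, 1) noise to each log-probability and take the
    # argmax; the winning index is distributed according to softmax(logps),
    # with no need to exponentiate or normalize first.
    return int(np.argmax(logps + rng.gumbel(size=len(logps))))

rng = np.random.default_rng(0)
logps = np.log(np.array([0.7, 0.2, 0.1]))
draws = [gumbel_max_sample(logps, rng) for _ in range(10_000)]
freqs = np.bincount(draws, minlength=3) / len(draws)
# freqs is close to [0.7, 0.2, 0.1]
```

Because the argmax is invariant to subtracting a constant, the trick also works on unnormalized log-weights, which is why sampling can proceed without materializing the normalized distribution.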

SetTokenSampler

Bases: TokenSampler

Samples individual tokens by sampling a weighted set of tokens and then selecting one proportional to its weight.

This class wraps a SetSampler.

Parameters:

- `set_sampler` (`SetSampler`, required): The set sampler to sample from.

Source code in genlm/control/sampler/token.py
class SetTokenSampler(TokenSampler):
    """Samples individual tokens by sampling a weighted set of tokens and then selecting one
    proportional to its weight.

    This class wraps a `SetSampler`.

    Args:
        set_sampler (SetSampler): The set sampler to sample from
    """

    def __init__(self, set_sampler):
        assert isinstance(set_sampler, SetSampler)
        super().__init__(set_sampler.target)
        self.set_sampler = set_sampler

    async def sample(self, context, draw=None):
        """Sample a token and weight by sampling a weighted set of tokens from the `set_sampler`
        and then selecting one proportional to its weight.

        Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the vocabulary of the set sampler's target potential,
        this method samples a token $x_n \\in \\textsf{set_sampler.target.vocab_eos}$ and a weight.

        The sampled token and weight are properly weighted with respect to
        $$
        \\textsf{set_sampler.target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
        $$

        The returned weight corresponds to the sum of the weights of the sampled set.

        Args:
            context (list[int]): A sequence of tokens in the vocabulary of the set sampler's target potential.

        Returns:
            (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the random
                choices made in sampling that token.

        Note:
            For properly weighted sampling, the `set_sampler` must assign correct weights to each token. See
            `SetSampler` for more details.
        """
        logws, logp = await self.set_sampler.sample_set(context, draw=draw)
        logps = logws.normalize()
        if draw is None:
            token = fast_sample_lazyweights(logps)
        else:
            token = draw(logps.exp().materialize())
        return token, logws.sum(), logp + logps[token]

    async def cleanup(self):
        """Clean up the sampler.

        This method should be called when the sampler is no longer needed.
        """
        await self.set_sampler.cleanup()

sample(context, draw=None) async

Sample a token and weight by sampling a weighted set of tokens from the set_sampler and then selecting one proportional to its weight.

Given a context of tokens \(x_1, \ldots, x_{n-1}\) in the vocabulary of the set sampler's target potential, this method samples a token \(x_n \in \textsf{set_sampler.target.vocab_eos}\) and a weight.

The sampled token and weight are properly weighted with respect to $$ \textsf{set_sampler.target.logw_next}(x_n | x_1, \ldots, x_{n-1}) $$

The returned weight corresponds to the sum of the weights of the sampled set.

Parameters:

- `context` (`list[int]`, required): A sequence of tokens in the vocabulary of the set sampler's target potential.

Returns:

- `(token, weight, logp)`: A tuple containing the sampled token, weight, and log-probability of the random choices made in sampling that token.

Note

For properly weighted sampling, the set_sampler must assign correct weights to each token. See SetSampler for more details.

Source code in genlm/control/sampler/token.py
async def sample(self, context, draw=None):
    """Sample a token and weight by sampling a weighted set of tokens from the `set_sampler`
    and then selecting one proportional to its weight.

    Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the vocabulary of the set sampler's target potential,
    this method samples a token $x_n \\in \\textsf{set_sampler.target.vocab_eos}$ and a weight.

    The sampled token and weight are properly weighted with respect to
    $$
    \\textsf{set_sampler.target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
    $$

    The returned weight corresponds to the sum of the weights of the sampled set.

    Args:
        context (list[int]): A sequence of tokens in the vocabulary of the set sampler's target potential.

    Returns:
        (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the random
            choices made in sampling that token.

    Note:
        For properly weighted sampling, the `set_sampler` must assign correct weights to each token. See
        `SetSampler` for more details.
    """
    logws, logp = await self.set_sampler.sample_set(context, draw=draw)
    logps = logws.normalize()
    if draw is None:
        token = fast_sample_lazyweights(logps)
    else:
        token = draw(logps.exp().materialize())
    return token, logws.sum(), logp + logps[token]

cleanup() async

Clean up the sampler.

This method should be called when the sampler is no longer needed.

Source code in genlm/control/sampler/token.py
async def cleanup(self):
    """Clean up the sampler.

    This method should be called when the sampler is no longer needed.
    """
    await self.set_sampler.cleanup()
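The select-proportional-to-weight step above can be sketched in plain Python. This is an illustrative stand-in (the function name and the use of a `dict` for the weighted set are assumptions, not the library's `LazyWeights` API):

```python
import math, random

def select_from_weighted_set(set_logws, rng):
    # set_logws: dict mapping each token in the sampled set to its log-weight.
    toks = list(set_logws)
    logZ = math.log(sum(math.exp(set_logws[t]) for t in toks))  # log of the set's total weight
    probs = [math.exp(set_logws[t] - logZ) for t in toks]       # normalized weights
    token = rng.choices(toks, weights=probs, k=1)[0]
    logp = math.log(probs[toks.index(token)])
    # The returned weight is the set's total weight (in log space),
    # mirroring `logws.sum()` in the source above.
    return token, logZ, logp

rng = random.Random(0)
tok, w, logp = select_from_weighted_set({"a": math.log(0.5), "b": math.log(1.5)}, rng)
# w == log(2.0); "b" is selected with probability 0.75, "a" with probability 0.25
```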

AWRS

Bases: TokenSampler

Samples individual tokens through an adaptive weighted rejection sampling algorithm.

This sampler is based on the algorithm described in Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling

It draws properly weighted samples from the product of a non-boolean potential and a boolean condition.

Parameters:

- `potential` (`Potential`, required): The non-boolean potential.
- `condition` (`Potential`, required): The boolean condition. This potential must only output boolean values (0 or -inf in log-space).
- `seed` (`int` or `None`, default `None`): The seed for the random number generator.
- `prune_logws` (`bool`, default `True`): Whether to prune the logws to only include the tokens in the intersection of the potential and condition vocabularies.
- `proper_weights` (`bool`, default `True`): Whether to return properly weighted samples. If `False`, the sampler will only run one round of adaptive rejection sampling.
- `max_accepts` (`int`, default `2`): The maximum number of tokens to accept; higher values will decrease the variance of the weight estimate.
- `max_rejects` (`int` or `float('inf')`, default `float('inf')`): The maximum number of tokens to reject; lower values will run faster, but at the cost of returning a weight of zero for some samples where a token would have been accepted if tested.
- `n_monte_carlo_samples` (`int`, default `None`): The number of Monte Carlo samples to use to estimate the weight. Deprecated: this parameter no longer has any effect.
Source code in genlm/control/sampler/token.py
class AWRS(TokenSampler):
    """Samples individual tokens through an adaptive weighted rejection sampling algorithm.

    This sampler is based on the algorithm described in [Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling](https://arxiv.org/abs/2504.05410)

    It draws properly weighted samples from the product of a non-boolean potential and a boolean condition.

    Args:
        potential (Potential): The non-boolean potential.
        condition (Potential): The boolean condition. This potential must only output boolean values (0 or -inf in log-space).
        seed (int or None): The seed for the random number generator.
        prune_logws (bool): Whether to prune the logws to only include the tokens in the intersection of the potential and condition vocabularies
        proper_weights (bool): Whether to return properly weighted samples.
            If False, the sampler will only run one round of adaptive rejection sampling.
        max_accepts (int): The maximum number of tokens to accept - higher values will decrease the variance of the weight estimate.
        max_rejects (int or float('inf')): The maximum number of tokens to reject - lower values will run faster, but at the cost of returning a weight of zero for some samples where there are tokens that would be accepted if tested.
        n_monte_carlo_samples (int): The number of Monte Carlo samples to use to estimate the weight. Higher values will decrease the variance of the weight estimate, but will run slower.
    """

    def __init__(
        self,
        potential,
        condition,
        seed=None,
        prune_logws=True,
        proper_weights=True,
        max_accepts=2,
        max_rejects=float("inf"),
        n_monte_carlo_samples=None,
    ):
        super().__init__(target=potential * condition)
        self.potential = potential
        self.condition = condition

        self.prune_logws = prune_logws
        self.proper_weights = proper_weights

        if max_accepts < 2 and proper_weights:
            raise ValueError("`max_accepts` must be at least 2")

        if max_rejects < 2 and proper_weights:
            raise ValueError("`max_rejects` must be at least 2")

        if n_monte_carlo_samples is not None:
            warnings.warn(
                "n_monte_carlo_samples no longer does anything.",
                DeprecationWarning,
            )

        self.max_accepts = max_accepts
        self.max_rejects = max_rejects or float("inf")

        self.valid_idxs = np.array(
            [self.potential.lookup[t] for t in self.target.vocab_eos]
        )

        self.vocab_eos_set = set(self.target.vocab_eos)
        self.V = len(self.potential.vocab_eos)
        self.rng = np.random.default_rng(seed=seed)

    def _prune_logws(self, logws):
        # Prune the logws to only include the tokens in the
        # target vocabulary. (This zeros-out tokens which we know a priori
        # will be rejected.) Note: We need an additional correction term
        # to account for the fact that we're throwing away some probability mass.
        # This should be handled in `sample`.
        pruned = self.potential.alloc_logws()
        pruned[self.valid_idxs] = logws.weights[self.valid_idxs]
        logws.weights = pruned
        return logws

    async def _accept(self, context, token, verbosity=0):
        if self.prune_logws or token in self.vocab_eos_set:
            if token is self.target.eos:
                logscore = await self.condition.complete(context)
            else:
                logscore = await self.condition.prefix(context + [token])
            assert logscore in {-np.inf, 0}, "`condition` must be Boolean"
        else:
            logscore = -np.inf

        do_accept = logscore == 0

        if verbosity > 0:
            if do_accept:
                print(colors.green % f". {repr(token)}")
            else:
                print(colors.red % ".", end="")

        return do_accept

    async def sample(self, context, verbosity=0):
        """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method via adaptive weighted rejection sampling.

        The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

        Returns:
            (token, weight, np.nan): A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.
        """
        logws = await self.potential.logw_next(context)
        if self.prune_logws:
            logws = self._prune_logws(logws)

        logZ = logsumexp(logws.weights)
        logps = logws.weights - logZ
        toks = logws.decode

        # We cache successful calls, as algorithms may want to see the
        # same successful token more than once (currently just geometric_awrs)
        cache = {}

        async def accept(tok):
            try:
                return cache[tok]
            except KeyError:
                pass
            result = await self._accept(context, tok, verbosity)
            if result:
                cache[tok] = result
            return result

        if not self.proper_weights:
            return await improper_sample(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
            )
        # We pick which algorithm to use based on parameters and the
        # shape of the distribution, as this lets us pick the most
        # effective option.
        elif (
            # If max_accepts is large then recursive_awrs (which
            # does not currently support this parameter) isn't very
            # useful, because the recursive step means that you never
            # revisit the same value, so will often throw away most
            # of the accepted mass if you were to continue. Also
            # this parameter is only really relevant if you want to
            # lower the variance, and geometric_awrs is lower variance.
            self.max_accepts > 2
            or
            # If the distribution is strongly peaked around a single value
            # then geometric_awrs will be more efficient. See below
            # for specific derivation.
            logps.max() >= GEOMETRIC_THRESHOLD
        ):
            tok, w, _ = await geometric_awrs(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
                max_accepts=self.max_accepts,
            )
            return tok, w + logZ, np.nan
        else:
            tok, w, _ = await recursive_awrs(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
            )
            return tok, w + logZ, np.nan

sample(context, verbosity=0) async

Sample a token and weight that are properly weighted with respect to the target potential's logw_next method via adaptive weighted rejection sampling.

The returned weight corresponds to the log normalizing constant of \(\textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1})\).

Returns:

- `(token, weight, np.nan)`: A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.

Source code in genlm/control/sampler/token.py
async def sample(self, context, verbosity=0):
    """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method via adaptive weighted rejection sampling.

    The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

    Returns:
        (token, weight, np.nan): A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.
    """
    logws = await self.potential.logw_next(context)
    if self.prune_logws:
        logws = self._prune_logws(logws)

    logZ = logsumexp(logws.weights)
    logps = logws.weights - logZ
    toks = logws.decode

    # We cache successful calls, as algorithms may want to see the
    # same successful token more than once (currently just geometric_awrs)
    cache = {}

    async def accept(tok):
        try:
            return cache[tok]
        except KeyError:
            pass
        result = await self._accept(context, tok, verbosity)
        if result:
            cache[tok] = result
        return result

    if not self.proper_weights:
        return await improper_sample(
            logps=logps,
            toks=toks,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
        )
    # We pick which algorithm to use based on parameters and the
    # shape of the distribution, as this lets us pick the most
    # effective option.
    elif (
        # If max_accepts is large then recursive_awrs (which
        # does not currently support this parameter) isn't very
        # useful, because the recursive step means that you never
        # revisit the same value, so will often throw away most
        # of the accepted mass if you were to continue. Also
        # this parameter is only really relevant if you want to
        # lower the variance, and geometric_awrs is lower variance.
        self.max_accepts > 2
        or
        # If the distribution is strongly peaked around a single value
        # then geometric_awrs will be more efficient. See below
        # for specific derivation.
        logps.max() >= GEOMETRIC_THRESHOLD
    ):
        tok, w, _ = await geometric_awrs(
            logps=logps,
            toks=toks,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
            max_accepts=self.max_accepts,
        )
        return tok, w + logZ, np.nan
    else:
        tok, w, _ = await recursive_awrs(
            logps=logps,
            toks=toks,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
        )
        return tok, w + logZ, np.nan
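The core idea that AWRS builds on is ordinary rejection against a boolean condition. The sketch below is the naive, non-adaptive baseline only, not the paper's algorithm, and the function name is hypothetical:

```python
import numpy as np

def simple_rejection(logps, accept, rng, max_rejects=1000):
    # Repeatedly draw a token from softmax(logps) and test the boolean
    # condition; the accepted token is distributed proportionally to
    # p(x) * accept(x). AWRS improves on this naive loop by removing
    # rejected tokens from the proposal (adaptivity) and by constructing
    # a properly weighted estimate of the normalizing constant.
    p = np.exp(logps - logps.max())
    p /= p.sum()
    for _ in range(max_rejects):
        tok = int(rng.choice(len(p), p=p))
        if accept(tok):
            return tok
    return None

rng = np.random.default_rng(0)
logps = np.log(np.array([0.6, 0.3, 0.1]))
tok = simple_rejection(logps, accept=lambda t: t != 0, rng=rng)
# tok is 1 or 2; token 0 is always rejected
```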

TokenSampler

Bases: SubModel

Base class for sampling a token from a potential's vocabulary.

TokenSamplers generate properly weighted samples with respect to a target potential.

Given a context of tokens \(x_1, \ldots, x_{n-1}\) in the target potential's vocabulary, a TokenSampler samples a token \(x_n \in \textsf{target.vocab_eos}\) and weight \(w\).

The sampled token and weight are properly weighted with respect to $$ \textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1}) $$

Parameters:

- `target` (`Potential`, required): The potential that samples are properly weighted with respect to.

Source code in genlm/control/sampler/token.py
class TokenSampler(SubModel):
    """Base class for sampling a token from a potential's vocabulary.

    `TokenSampler`s generate properly weighted samples with respect to a `target` potential.

    Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the target potential's vocabulary,
    a `TokenSampler` samples a token $x_n \\in \\textsf{target.vocab_eos}$ and weight $w$.

    The sampled token and weight are properly weighted with respect to
    $$
    \\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
    $$

    Args:
        target (Potential): The potential that samples are properly weighted with respect to.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target
        self.token_type = self.target.token_type

    async def start_weight(self):
        """Compute the weight of the empty sequence under the target potential."""
        return await self.target.prefix([])

    async def forward(self):
        parent = self.parent  # For some reason, need to hold onto this reference.
        token, logw, logp = await self.sample(parent.token_ctx)
        parent.score(logw)
        parent.logp += logp
        return token

    async def sample(self, context, draw):
        """Sample a token and weight from the `target` potential's vocabulary.

        Args:
            context (list[int]): A sequence of tokens in the `target` potential's vocabulary.
            draw (callable): A callable that draws a sample from a distribution.

        Returns:
            (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the sampled token.
        """
        raise NotImplementedError(
            "Subclasses must implement sample method"
        )  # pragma: no cover

    async def cleanup(self):
        pass  # pragma: no cover

    async def smc(self, n_particles, ess_threshold, max_tokens, critic=None, **kwargs):
        """Generate sequences using sequential Monte Carlo (SMC) inference with this token sampler and an optional critic.

        This method is a convenience wrapper around [`SMC`][genlm.control.sampler.sequence.SMC].
        See [`SMC`][genlm.control.sampler.sequence.SMC] for more details on the generation process.

        Args:
            n_particles (int): The number of particles to use in the SMC algorithm.
            ess_threshold (float): The threshold for the effective sample size (ESS).
            max_tokens (int): The maximum number of tokens to generate.
            critic (Potential, optional): A potential function that guides the generation process
                by scoring candidate sequences. Must have the same token type as the token sampler.
            **kwargs (dict): Additional keyword arguments to pass to `SMC`'s `__call__` method.
        """
        from genlm.control.sampler.sequence import SMC

        return await SMC(self, critic)(
            n_particles=n_particles,
            ess_threshold=ess_threshold,
            max_tokens=max_tokens,
            **kwargs,
        )
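The proper-weighting guarantee stated in the class docstring can be checked numerically. A sample `(x, w)` is properly weighted for unnormalized weights `W` if `E[w * f(x)] = sum_x W(x) * f(x)` for any `f`. The sketch below (all names illustrative, not part of the library) verifies this for direct sampling, where `x ~ W/Z` and `w = Z`:

```python
import numpy as np

# For x ~ W/Z with w = Z = sum_x W(x):
#   E[w * 1{x = a}] = Z * (W(a) / Z) = W(a).
# A quick Monte Carlo check of that identity:
W = np.array([0.5, 1.5, 2.0])   # unnormalized target weights
Z = W.sum()
rng = np.random.default_rng(0)
xs = rng.choice(3, size=50_000, p=W / Z)
est = np.array([Z * np.mean(xs == a) for a in range(3)])
# est is close to W = [0.5, 1.5, 2.0]
```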

start_weight() async

Compute the weight of the empty sequence under the target potential.

Source code in genlm/control/sampler/token.py
async def start_weight(self):
    """Compute the weight of the empty sequence under the target potential."""
    return await self.target.prefix([])

sample(context, draw) async

Sample a token and weight from the target potential's vocabulary.

Parameters:

- `context` (`list[int]`, required): A sequence of tokens in the `target` potential's vocabulary.
- `draw` (`callable`, required): A callable that draws a sample from a distribution.

Returns:

- `(token, weight, logp)`: A tuple containing the sampled token, weight, and log-probability of the sampled token.

Source code in genlm/control/sampler/token.py
async def sample(self, context, draw):
    """Sample a token and weight from the `target` potential's vocabulary.

    Args:
        context (list[int]): A sequence of tokens in the `target` potential's vocabulary.
        draw (callable): A callable that draws a sample from a distribution.

    Returns:
        (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the sampled token.
    """
    raise NotImplementedError(
        "Subclasses must implement sample method"
    )  # pragma: no cover

smc(n_particles, ess_threshold, max_tokens, critic=None, **kwargs) async

Generate sequences using sequential Monte Carlo (SMC) inference with this token sampler and an optional critic.

This method is a convenience wrapper around SMC. See SMC for more details on the generation process.

Parameters:

- `n_particles` (`int`, required): The number of particles to use in the SMC algorithm.
- `ess_threshold` (`float`, required): The threshold for the effective sample size (ESS).
- `max_tokens` (`int`, required): The maximum number of tokens to generate.
- `critic` (`Potential`, optional, default `None`): A potential function that guides the generation process by scoring candidate sequences. Must have the same token type as the token sampler.
- `**kwargs` (`dict`, default `{}`): Additional keyword arguments to pass to `SMC`'s `__call__` method.

Source code in genlm/control/sampler/token.py
async def smc(self, n_particles, ess_threshold, max_tokens, critic=None, **kwargs):
    """Generate sequences using sequential Monte Carlo (SMC) inference with this token sampler and an optional critic.

    This method is a convenience wrapper around [`SMC`][genlm.control.sampler.sequence.SMC].
    See [`SMC`][genlm.control.sampler.sequence.SMC] for more details on the generation process.

    Args:
        n_particles (int): The number of particles to use in the SMC algorithm.
        ess_threshold (float): The threshold for the effective sample size (ESS).
        max_tokens (int): The maximum number of tokens to generate.
        critic (Potential, optional): A potential function that guides the generation process
            by scoring candidate sequences. Must have the same token type as the token sampler.
        **kwargs (dict): Additional keyword arguments to pass to `SMC`'s `__call__` method.
    """
    from genlm.control.sampler.sequence import SMC

    return await SMC(self, critic)(
        n_particles=n_particles,
        ess_threshold=ess_threshold,
        max_tokens=max_tokens,
        **kwargs,
    )

EagerSetSampler

Bases: TrieSetSampler

A trie-based set sampler that implements an eager sampling strategy for generating a set of tokens.

An EagerSetSampler samples tokens by incrementally sampling items from the item-wise product of the iter_potential and item_potential. The sampled set is the set of sequences of items that correspond to valid tokens in iter_potential's vocabulary.

Source code in genlm/control/sampler/set.py
class EagerSetSampler(TrieSetSampler):
    """
    A trie-based set sampler that implements an eager sampling strategy
    for generating a set of tokens.

    An `EagerSetSampler` samples tokens by incrementally sampling items from the item-wise product of the `iter_potential` and `item_potential`.
    The sampled set is the set of sequences of items that correspond to valid tokens in `iter_potential`'s vocabulary.
    """

    async def sample_set(self, context, draw=None):
        """
        Sample a set of tokens given a context.

        Args:
            context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

        Returns:
            (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
        """
        if draw is None:
            draw = sample_dict
        iter_logws = await self.iter_potential.logw_next(context)
        item_ws = await self.trie_executor.weight_sum(iter_logws.exp().weights)

        logws = self.target.alloc_logws()
        curr = self.trie.root
        coerced_ctx = self.f(context)
        subtokens = []
        logp, logw = 0, 0

        while True:
            children = self.trie.children[curr]
            item_w_curr = item_ws[curr]
            item_ws1 = Float.chart(
                {a: item_ws[c] / item_w_curr for a, c in children.items()}
            )

            if None in item_ws1:
                leaf = children[None]
                token = self.trie.leaf2word[leaf]
                token_id = self.leaf_to_token_id[leaf]
                logws[token_id] = iter_logws[token] + logw - logp

            item_logws2 = await self.item_potential.logw_next(coerced_ctx + subtokens)
            item_ws2 = item_logws2.exp().materialize()
            w_next = (item_ws1 * item_ws2).trim()

            if not w_next:
                break

            ps = w_next.normalize()
            b = draw(ps)
            logp += np.log(ps[b])
            logw += item_logws2[b]

            if b == self.target.eos:
                assert not subtokens, "subtokens should be empty at EOS."
                logws[-1] = iter_logws[self.target.eos] + logw - logp
                break

            subtokens.append(b)
            curr = children[b]

        return self.target.make_lazy_weights(logws), logp

sample_set(context, draw=None) async

Sample a set of tokens given a context.

Parameters:

- `context` (`list`, required): A sequence of tokens in the `iter_potential`'s vocabulary.

Returns:

- `(LazyWeights, float)`: A weighted set of tokens and the log-probability of the sampled set.

Source code in genlm/control/sampler/set.py
async def sample_set(self, context, draw=None):
    """
    Sample a set of tokens given a context.

    Args:
        context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

    Returns:
        (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
    """
    if draw is None:
        draw = sample_dict
    iter_logws = await self.iter_potential.logw_next(context)
    item_ws = await self.trie_executor.weight_sum(iter_logws.exp().weights)

    logws = self.target.alloc_logws()
    curr = self.trie.root
    coerced_ctx = self.f(context)
    subtokens = []
    logp, logw = 0, 0

    while True:
        children = self.trie.children[curr]
        item_w_curr = item_ws[curr]
        item_ws1 = Float.chart(
            {a: item_ws[c] / item_w_curr for a, c in children.items()}
        )

        if None in item_ws1:
            leaf = children[None]
            token = self.trie.leaf2word[leaf]
            token_id = self.leaf_to_token_id[leaf]
            logws[token_id] = iter_logws[token] + logw - logp

        item_logws2 = await self.item_potential.logw_next(coerced_ctx + subtokens)
        item_ws2 = item_logws2.exp().materialize()
        w_next = (item_ws1 * item_ws2).trim()

        if not w_next:
            break

        ps = w_next.normalize()
        b = draw(ps)
        logp += np.log(ps[b])
        logw += item_logws2[b]

        if b == self.target.eos:
            assert not subtokens, "subtokens should be empty at EOS."
            logws[-1] = iter_logws[self.target.eos] + logw - logp
            break

        subtokens.append(b)
        curr = children[b]

    return self.target.make_lazy_weights(logws), logp

TopKSetSampler

Bases: TrieSetSampler

A trie-based set sampler that lazily enumerates the top K tokens by weight in the target, and samples an additional "wildcard" token to ensure absolute continuity.

Warning

This sampler is not guaranteed to be correct if the item_potential's prefix weights do not monotonically decrease with the length of the context. That is, correctness requires \(\textsf{item_potential.prefix}(xy) \leq \textsf{item_potential.prefix}(x)\) for all sequences of items \(x, y\).
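The lazy top-K enumeration rests on a best-first search: because each prefix weight upper-bounds the weight of any word it can complete, popping trie nodes from a max-heap yields complete words in descending weight order. The sketch below is a self-contained illustration in plain Python with `heapq`, not the library's trie or `Float.chart` API:

```python
import heapq

def topk_words(trie, weights, k):
    """Best-first enumeration of the k highest-weight words in a trie.

    `trie` maps each internal prefix to its child prefixes; `weights` maps
    each prefix to an upper bound on the weight of any word it completes
    (exact at leaves). Monotonically decreasing prefix weights guarantee
    that leaves pop off the heap in descending order of exact weight.
    """
    heap = [(-weights[""], "")]  # max-heap via negated priorities
    out = []
    while heap and len(out) < k:
        neg_w, prefix = heapq.heappop(heap)
        if prefix in trie:  # internal node: expand its children
            for child in trie[prefix]:
                heapq.heappush(heap, (-weights[child], child))
        else:  # leaf: weight is exact, so this is the next-best word
            out.append((prefix, -neg_w))
    return out

# Tiny trie whose complete words are "ab", "ac", and "b".
trie = {"": ["a", "b"], "a": ["ab", "ac"]}
weights = {"": 1.0, "a": 0.9, "b": 0.3, "ab": 0.6, "ac": 0.2}
```

Here `topk_words(trie, weights, 2)` yields `[("ab", 0.6), ("b", 0.3)]`; `TopKSetSampler` additionally samples one wildcard token from the unenumerated remainder so the set sampler stays absolutely continuous.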

Source code in genlm/control/sampler/set.py
class TopKSetSampler(TrieSetSampler):
    """
    A trie-based set sampler that lazily enumerates the top K tokens by weight in the target,
    and samples an additional "wildcard" token to ensure absolute continuity.

    Warning:
        This sampler is not guaranteed to be correct if the `item_potential`'s
        prefix weights do not monotonically decrease with the length of the context.
        That is, correctness requires $\\textsf{item_potential.prefix}(xy) \\leq \\textsf{item_potential.prefix}(x)$ for all sequences of items $x, y$.
    """

    def __init__(self, iter_potential, item_potential, K):
        """
        Initialize the TopKSetSampler.

        Args:
            iter_potential (Potential): The potential defined over a vocabulary of iterables.
            item_potential (Potential): The potential defined over a vocabulary of items.
            K (int|None): The number of top tokens to enumerate. If None, all tokens are enumerated.
        """
        if K is not None and K <= 0:
            raise ValueError("K must be greater than 0 or None")
        super().__init__(iter_potential, item_potential)
        self.K = K

    async def sample_set(self, context, draw=None):
        """
        Sample a set of tokens given a context.

        Args:
            context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

        Returns:
            (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
        """
        if draw is None:
            draw = sample_dict
        iter_logws = await self.iter_potential.logw_next(context)
        max_logws = await self.trie_executor.weight_max(iter_logws.weights)

        k = 0
        logws = self.target.alloc_logws()
        sampled = self.target.alloc_logws(default=False)

        async for token_id, logw in self._lazy_enum(context, max_logws):
            logws[token_id] = logw
            sampled[token_id] = True
            k += 1
            if self.K is not None and k >= self.K:
                break

        logp_wc = 0
        if self.K is not None and k == self.K:
            # Get the distribution over wildcard tokens
            iter_ws = iter_logws.exp()
            W_wc = Float.chart(
                {
                    token_id: iter_ws[token]
                    for token_id, token in enumerate(self.target.vocab_eos)
                    if not sampled[token_id]
                }
            )

            # if W_wc is non-empty, sample a wildcard token to ensure absolute continuity
            if W_wc:
                P_wc = W_wc.normalize()
                wc_id = draw(P_wc)
                logp_wc = np.log(P_wc[wc_id])
                wc = self.target.vocab_eos[wc_id]
                item_ctx = self.f(context)
                prefix_w = await self.item_potential.prefix(item_ctx)
                if wc == self.target.eos:
                    w_guide_wc = await self.item_potential.complete(item_ctx) - prefix_w
                else:
                    w_guide_wc = (
                        await self.item_potential.prefix(self.f(context + [wc]))
                        - prefix_w
                    )
                logws[wc_id] = np.log(W_wc[wc_id]) + w_guide_wc - logp_wc

        return self.target.make_lazy_weights(logws), logp_wc

    async def _lazy_enum(self, context, max_logws):
        agenda = LocatorMaxHeap()

        W = Float.chart()

        # initial conditions
        (token, node) = ((), self.trie.root)
        agenda[token, node, False] = max_logws[node]
        W[node] = 0

        children = self.trie.children
        coerced_ctx = self.f(context)

        curr_priority = float("inf")
        prev_best = float("inf")
        while agenda:
            (token, node, done), score = agenda.popitem()

            assert score <= curr_priority, (
                "Monotonicity assumption violated. "
                "`item_potential` prefix weight must be monotonically decreasing."
            )
            curr_priority = score

            # terminal state
            if done:
                value = W[node] + max_logws[node]
                assert prev_best >= value
                prev_best = value
                yield (self.leaf_to_token_id[node], value)
                continue

            logws = None
            for x, y in children[node].items():
                if x is None:
                    W_y = W[node]
                    W[y] = W_y
                    agenda[token, y, True] = W_y + max_logws[y]
                else:
                    if logws is None:
                        logws = await self.item_potential.logw_next(
                            coerced_ctx + list(token)
                        )
                    W_y = W[node] + logws[x]
                    if W_y == float("-inf"):
                        continue
                    W[y] = W_y
                    agenda[(*token, x), y, False] = W_y + max_logws[y]

__init__(iter_potential, item_potential, K)

Initialize the TopKSetSampler.

Parameters:

Name Type Description Default
iter_potential Potential

The potential defined over a vocabulary of iterables.

required
item_potential Potential

The potential defined over a vocabulary of items.

required
K int | None

The number of top tokens to enumerate. If None, all tokens are enumerated.

required
Source code in genlm/control/sampler/set.py
def __init__(self, iter_potential, item_potential, K):
    """
    Initialize the TopKSetSampler.

    Args:
        iter_potential (Potential): The potential defined over a vocabulary of iterables.
        item_potential (Potential): The potential defined over a vocabulary of items.
        K (int|None): The number of top tokens to enumerate. If None, all tokens are enumerated.
    """
    if K is not None and K <= 0:
        raise ValueError("K must be greater than 0 or None")
    super().__init__(iter_potential, item_potential)
    self.K = K

sample_set(context, draw=None) async

Sample a set of tokens given a context.

Parameters:

Name Type Description Default
context list

A sequence of tokens in the iter_potential's vocabulary.

required

Returns:

Type Description
(LazyWeights, float)

A weighted set of tokens and the log-probability of the sampled set.

Source code in genlm/control/sampler/set.py
async def sample_set(self, context, draw=None):
    """
    Sample a set of tokens given a context.

    Args:
        context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

    Returns:
        (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
    """
    if draw is None:
        draw = sample_dict
    iter_logws = await self.iter_potential.logw_next(context)
    max_logws = await self.trie_executor.weight_max(iter_logws.weights)

    k = 0
    logws = self.target.alloc_logws()
    sampled = self.target.alloc_logws(default=False)

    async for token_id, logw in self._lazy_enum(context, max_logws):
        logws[token_id] = logw
        sampled[token_id] = True
        k += 1
        if self.K is not None and k >= self.K:
            break

    logp_wc = 0
    if self.K is not None and k == self.K:
        # Get the distribution over wildcard tokens
        iter_ws = iter_logws.exp()
        W_wc = Float.chart(
            {
                token_id: iter_ws[token]
                for token_id, token in enumerate(self.target.vocab_eos)
                if not sampled[token_id]
            }
        )

        # if W_wc is non-empty, sample a wildcard token to ensure absolute continuity
        if W_wc:
            P_wc = W_wc.normalize()
            wc_id = draw(P_wc)
            logp_wc = np.log(P_wc[wc_id])
            wc = self.target.vocab_eos[wc_id]
            item_ctx = self.f(context)
            prefix_w = await self.item_potential.prefix(item_ctx)
            if wc == self.target.eos:
                w_guide_wc = await self.item_potential.complete(item_ctx) - prefix_w
            else:
                w_guide_wc = (
                    await self.item_potential.prefix(self.f(context + [wc]))
                    - prefix_w
                )
            logws[wc_id] = np.log(W_wc[wc_id]) + w_guide_wc - logp_wc

    return self.target.make_lazy_weights(logws), logp_wc

SMC

This class implements sequential Monte Carlo (SMC) inference for controlled text generation. The generation process works as follows:

  1. Token Sampling: At each step, the unit_sampler is used to extend each particle (candidate sequence) by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs an importance weight with each extension to correct for the myopic nature of token-by-token sampling.

  2. Critic Evaluation: If a critic is provided, it scores the updated sequences (via its score method), reweighting the particles based on how well they satisfy the constraints encoded by the critic.

  3. Resampling: When the effective sample size (ESS) falls below the threshold, particles are resampled according to their weights. This helps focus computation on more promising sequences.

  4. Termination: The process continues until either:

    • All sequences reach an end-of-sequence (EOS) token

    • The maximum token length is reached

If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's target potential and the critic potential (unit_sampler.target * critic). If a critic is not provided, the resulting sequences are weighted with respect to the unit sampler's target potential.
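The resampling rule in step 3 can be sketched in a few lines of NumPy. This is standard SMC bookkeeping under a multinomial-resampling assumption, not the library's internal implementation:

```python
import numpy as np

def maybe_resample(logw, ess_threshold, rng):
    """Resample particle indices when the effective sample size is low.

    `logw` holds one log importance weight per particle. The ESS is
    computed from the normalized weights; resampling triggers when it
    falls below `ess_threshold * n_particles`.
    """
    w = np.exp(logw - logw.max())       # shift by max for numerical stability
    p = w / w.sum()                     # normalized particle weights
    ess = 1.0 / np.sum(p ** 2)          # effective sample size
    if ess >= ess_threshold * len(logw):
        return np.arange(len(logw)), logw
    idx = rng.choice(len(logw), size=len(logw), p=p)
    # After resampling, every particle carries the mean weight, so the
    # ensemble still estimates the same normalizing constant.
    new_logw = np.full(len(logw), logw.max() + np.log(w.mean()))
    return idx, new_logw
```

With equal weights the ESS equals the number of particles and no resampling occurs; with one dominant particle the ESS approaches 1 and low-weight particles are replaced by copies of the dominant one.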

Parameters:

Name Type Description Default
unit_sampler TokenSampler

The sampler that generates tokens.

required
critic Potential

A potential function that guides the generation process by scoring candidate sequences. Must have the same token type as the unit_sampler.

None

Raises:

Type Description
ValueError

If unit_sampler is not a TokenSampler, if critic is not a Potential, or if the token types of unit_sampler and critic don't match.

Source code in genlm/control/sampler/sequence.py
class SMC:
    """This class implements sequential Monte Carlo (SMC) inference for controlled text generation.
    The generation process works as follows:

    1. Token Sampling: At each step, the `unit_sampler` is used to extend each particle (candidate sequence)
       by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs
       an importance weight with each extension to correct for the myopic nature of token-by-token sampling.

    2. Critic Evaluation: If a `critic` is provided, it scores the updated sequences (via its `score` method),
       reweighting the particles based on how well they satisfy the constraints encoded by the critic.

    3. Resampling: When the effective sample size (ESS) falls below the threshold,
       particles are resampled according to their weights. This helps focus computation
       on more promising sequences.

    4. Termination: The process continues until either:\n
        - All sequences reach an end-of-sequence (EOS) token\n
        - The maximum token length is reached

    If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's
    target potential and the critic potential (`unit_sampler.target * critic`). If a critic is not provided,
    the resulting sequences are weighted with respect to the unit sampler's target potential.

    Args:
        unit_sampler (TokenSampler): The sampler that generates tokens.
        critic (Potential, optional): A potential function that guides the generation process
            by scoring candidate sequences. Must have the same token type as the unit_sampler.

    Raises:
        ValueError: If unit_sampler is not a TokenSampler, if critic is not a Potential,
            or if the token types of unit_sampler and critic don't match.
    """

    def __init__(self, unit_sampler, critic=None):
        if not isinstance(unit_sampler, TokenSampler):
            raise ValueError("`unit_sampler` must be a TokenSampler")

        if critic:
            if not isinstance(critic, Potential):
                raise ValueError("`critic` must be a Potential")
            if not unit_sampler.token_type == critic.token_type:
                raise ValueError(
                    "`critic` must have the same token type as the `unit_sampler`. "
                    f"Got {unit_sampler.token_type} and {critic.token_type}."
                    + (
                        "\nMaybe you forgot to coerce the critic to the token type of the unit sampler? See `Coerce`."
                        if unit_sampler.token_type.is_iterable_of(critic.token_type)
                        else ""
                    )
                )

        self.unit_sampler = unit_sampler
        self.critic = critic

    async def __call__(
        self,
        n_particles,
        ess_threshold,
        max_tokens,
        verbosity=0,
        json_path=None,
        **kwargs,
    ):
        """Generate sequences using sequential Monte Carlo inference.

        Args:
            n_particles (int): Number of particles (candidate sequences) to maintain during
                generation. Higher values provide better exploration but require more
                computation.
            ess_threshold (float): Effective sample size threshold for resampling,
                expressed as a fraction of the number of particles. When ESS falls below
                this value, particles are resampled according to their weights. Should be between 0 and 1.
                Higher values lead to more frequent resampling. Note that when ess_threshold = 0,
                the critic is only applied at the end of the generation (if it is provided).
            max_tokens (int): Maximum number of tokens to generate per sequence. Generation
                may terminate earlier if all sequences reach an EOS token.
            verbosity (int, optional): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the
                particles at each step. Default is 0.
            json_path (str, optional): JSON file path for saving a record of the inference run.
                This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
            **kwargs (dict): Additional keyword arguments to pass to the SMC algorithm.
                See the `llamppl.inference.smc_standard` documentation for more details.

        Returns:
            (Sequences): A container holding the generated sequences, their importance weights, and
                other metadata from the generation process.
        """
        model = SequenceModel(
            unit_sampler=self.unit_sampler,
            critic=self.critic,
            max_tokens=max_tokens,
            verbosity=verbosity,
            twist_with_critic=ess_threshold > 0,
        )

        particles = await smc_standard(
            model=model,
            n_particles=n_particles,
            ess_threshold=ess_threshold,
            json_file=json_path,
            **kwargs,
        )

        return Sequences(*_unpack_particles(particles))

    async def cleanup(self):
        """Clean up resources used by the inference engine.

        This method should be called when the `SMC` instance is no longer needed.

        Example:
            ```python
            sampler = SMC(unit_sampler, critic)
            try:
                sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
            finally:
                await sampler.cleanup()
            ```
        """
        await self.unit_sampler.cleanup()
        if self.critic:
            await self.critic.cleanup()

__call__(n_particles, ess_threshold, max_tokens, verbosity=0, json_path=None, **kwargs) async

Generate sequences using sequential Monte Carlo inference.

Parameters:

Name Type Description Default
n_particles int

Number of particles (candidate sequences) to maintain during generation. Higher values provide better exploration but require more computation.

required
ess_threshold float

Effective sample size threshold for resampling, expressed as a fraction of the number of particles. When ESS falls below this value, particles are resampled according to their weights. Should be between 0 and 1. Higher values lead to more frequent resampling. Note that when ess_threshold = 0, the critic is only applied at the end of the generation (if it is provided).

required
max_tokens int

Maximum number of tokens to generate per sequence. Generation may terminate earlier if all sequences reach an EOS token.

required
verbosity int

Verbosity level for the SMC algorithm. 0 is silent, 1 prints the particles at each step. Default is 0.

0
json_path str

JSON file path for saving a record of the inference run. This can be used in conjunction with the InferenceVisualizer to visualize the inference run.

None
**kwargs dict

Additional keyword arguments to pass to the SMC algorithm. See the llamppl.inference.smc_standard documentation for more details.

{}

Returns:

Type Description
Sequences

A container holding the generated sequences, their importance weights, and other metadata from the generation process.

Source code in genlm/control/sampler/sequence.py
async def __call__(
    self,
    n_particles,
    ess_threshold,
    max_tokens,
    verbosity=0,
    json_path=None,
    **kwargs,
):
    """Generate sequences using sequential Monte Carlo inference.

    Args:
        n_particles (int): Number of particles (candidate sequences) to maintain during
            generation. Higher values provide better exploration but require more
            computation.
        ess_threshold (float): Effective sample size threshold for resampling,
            expressed as a fraction of the number of particles. When ESS falls below
            this value, particles are resampled according to their weights. Should be between 0 and 1.
            Higher values lead to more frequent resampling. Note that when ess_threshold = 0,
            the critic is only applied at the end of the generation (if it is provided).
        max_tokens (int): Maximum number of tokens to generate per sequence. Generation
            may terminate earlier if all sequences reach an EOS token.
        verbosity (int, optional): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the
            particles at each step. Default is 0.
        json_path (str, optional): JSON file path for saving a record of the inference run.
            This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
        **kwargs (dict): Additional keyword arguments to pass to the SMC algorithm.
            See the `llamppl.inference.smc_standard` documentation for more details.

    Returns:
        (Sequences): A container holding the generated sequences, their importance weights, and
            other metadata from the generation process.
    """
    model = SequenceModel(
        unit_sampler=self.unit_sampler,
        critic=self.critic,
        max_tokens=max_tokens,
        verbosity=verbosity,
        twist_with_critic=ess_threshold > 0,
    )

    particles = await smc_standard(
        model=model,
        n_particles=n_particles,
        ess_threshold=ess_threshold,
        json_file=json_path,
        **kwargs,
    )

    return Sequences(*_unpack_particles(particles))

cleanup() async

Clean up resources used by the inference engine.

This method should be called when the SMC instance is no longer needed.

Example
sampler = SMC(unit_sampler, critic)
try:
    sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
finally:
    await sampler.cleanup()
Source code in genlm/control/sampler/sequence.py
async def cleanup(self):
    """Clean up resources used by the inference engine.

    This method should be called when the `SMC` instance is no longer needed.

    Example:
        ```python
        sampler = SMC(unit_sampler, critic)
        try:
            sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
        finally:
            await sampler.cleanup()
        ```
    """
    await self.unit_sampler.cleanup()
    if self.critic:
        await self.critic.cleanup()

MultiTokenUnitSampler

Bases: TokenSampler

Sampler that groups multiple tokens into larger units.

This sampler enables generation at a coarser granularity than individual tokens by repeatedly sampling tokens until a boundary condition is met. Common use cases:

  • Word-level sampling: Group tokens until a word boundary (e.g., whitespace)
  • Sentence-level sampling: Group tokens until punctuation marks
  • Grammar-based units: Group tokens completing a grammar terminal

The sampler delegates to a subunit_sampler (typically a token-level sampler) and accumulates samples until the boundary_predicate signals completion. The final weight is the product of weights from each individual token sample. This ensures that sampling remains properly weighted w.r.t. the target potential.
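A boundary predicate in this spirit can be as simple as inspecting the last sampled token. The function below is an illustrative stand-in, not the library's `BoundaryPredicate` interface:

```python
def word_boundary(unit_context, subtoken_buffer):
    """Hypothetical predicate: a unit is complete once its newest token
    ends in whitespace (so units keep their trailing delimiter)."""
    return bool(subtoken_buffer) and subtoken_buffer[-1].endswith((b" ", b"\n"))
```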

Weight calculation: If sampling a unit requires \(n\) token samples with weights \(w_1, w_2, \ldots, w_n\), the unit weight is \(w = \prod_{i=1}^{n} w_i\) (or \(\log w = \sum_{i=1}^{n} \log w_i\) in log-space).
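As a concrete instance of this rule, three token samples with (made-up) weights 0.5, 0.8, and 0.25 combine into a unit weight of 0.1:

```python
import math

token_weights = [0.5, 0.8, 0.25]          # per-token importance weights w_i
unit_weight = math.prod(token_weights)    # w = 0.5 * 0.8 * 0.25 = 0.1
unit_logw = sum(math.log(w) for w in token_weights)  # log-space accumulation
assert math.isclose(math.exp(unit_logw), unit_weight)
```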

Parameters:

Name Type Description Default
subunit_sampler TokenSampler

Sampler for subunits \(s \in \mathcal{B}\)

required
boundary_predicate BoundaryPredicate

Determines when a sequence of tokens forms a complete unit. Also controls how to finalize the unit via finalize_unit().

required
max_subunits_per_unit int

Safety cap on the number of subunits per unit, to prevent non-termination. Default: 100.

100
Example

    # Sample word-level units (multi-token)
    llm = PromptedLLM.from_name("gpt2")
    subunit_sampler = DirectTokenSampler(llm)

    # Word boundaries at whitespace
    boundary = TokenSetBoundary({b" ", b"\n"})
    unit_sampler = MultiTokenUnitSampler(
        subunit_sampler=subunit_sampler,
        boundary_predicate=boundary,
        max_subunits_per_unit=50
    )
    # Units will be words WITH trailing space: [b"hello", b" "]

Source code in genlm/control/sampler/unit.py
class MultiTokenUnitSampler(TokenSampler):
    """Sampler that groups multiple tokens into larger units.

    This sampler enables generation at a coarser granularity than individual tokens
    by repeatedly sampling tokens until a boundary condition is met. Common use cases:

    - **Word-level sampling**: Group tokens until a word boundary (e.g., whitespace)
    - **Sentence-level sampling**: Group tokens until punctuation marks
    - **Grammar-based units**: Group tokens completing a grammar terminal

    The sampler delegates to a `subunit_sampler` (typically a token-level sampler)
    and accumulates samples until the `boundary_predicate` signals completion. The final
    weight is the product of weights from each individual token sample. This ensures that
    sampling remains properly weighted w.r.t. the target potential.

    **Weight calculation**: If sampling a unit requires $n$ token samples with weights
    $w_1, w_2, \\ldots, w_n$, the unit weight is $w = \\prod_{i=1}^{n} w_i$ (or
    $\\log w = \\sum_{i=1}^{n} \\log w_i$ in log-space).

    Args:
        subunit_sampler (TokenSampler): Sampler for subunits $s \\in \\mathcal{B}$
        boundary_predicate (BoundaryPredicate): Determines when a sequence of tokens forms
            a complete unit. Also controls how to finalize the unit via `finalize_unit()`.
        max_subunits_per_unit (int): Safety cap on the number of subunits per unit, to prevent non-termination. Default: 100.

    Example:
        >>> # Sample word-level units (multi-token)
        >>> llm = PromptedLLM.from_name("gpt2")
        >>> subunit_sampler = DirectTokenSampler(llm)
        >>>
        >>> # Word boundaries at whitespace
        >>> boundary = TokenSetBoundary({b" ", b"\\n"})
        >>> unit_sampler = MultiTokenUnitSampler(
        ...     subunit_sampler=subunit_sampler,
        ...     boundary_predicate=boundary,
        ...     max_subunits_per_unit=50
        ... )
        >>> # Units will be words WITH trailing space: [b"hello", b" "]
    """

    def __init__(
        self,
        subunit_sampler,
        boundary_predicate,
        max_subunits_per_unit=100,
    ):
        if not isinstance(subunit_sampler, TokenSampler):
            raise TypeError(
                f"subunit_sampler must be a TokenSampler, got {type(subunit_sampler)}"
            )

        # Initialized with subunit sampler's target
        # We may want to add support for different samplers in the future
        super().__init__(target=subunit_sampler.target)

        self.subunit_sampler = subunit_sampler
        self.boundary_predicate = boundary_predicate
        self.max_subunits_per_unit = max_subunits_per_unit

    async def start_weight(self):
        """Return $\\overrightarrow{\\psi}(\\epsilon)$ (prefix weight of empty sequence)."""
        return await self.subunit_sampler.start_weight()

    async def forward(self):
        """Called by LLaMPPL Model.call() to sample one multi-token unit.

        Called by SequenceModel.step() when it calls self.call(unit_sampler).
        """
        parent = self.parent

        # Flatten parent.token_ctx before passing to sample
        # This ensures sample() always works with a flat list
        flat_context = flatten_units(parent.token_ctx)

        # Sample multi-token unit, passing both flat context and structured unit context
        unit, logw, logp = await self.sample(
            flat_context, unit_context=parent.token_ctx, draw=None
        )

        # Update parent's weight and logp
        parent.score(logw)
        parent.logp += logp

        # If the unit ends with EOS, return EOS directly so SequenceModel can detect completion
        # SequenceModel.step() checks `token_ctx[-1] is EOS` to finish generation
        if unit and unit[-1] is EOS:
            # Keep the unit content before EOS in token_ctx, then return EOS separately
            if len(unit) > 1:
                parent.token_ctx.append(unit[:-1])  # Add unit without EOS
            return EOS  # Return EOS directly for SequenceModel to detect

        return unit

    async def sample(self, flat_token_context, unit_context=None, draw=None):
        """Sample a multi-token unit by running sequence sampling for $\\varphi_{\\bm{x}}$.
        SIS for the localized potential:

        1. Repeatedly sample $(s_i, w_i) \\sim q_{\\text{sub}}(\\cdot \\mid \\bm{s}_{<i})$ until boundary
        2. Accumulate weights: $w = \\overrightarrow{\\psi}_{\\bm{x}}(\\epsilon) \\prod_i w_i$
        3. Return $(\\bm{s}, w)$ where $\\bm{s} \\in \\mathcal{B}^*$ forms unit $x \\in \\mathcal{A}$

        Args:
            flat_token_context (list): Flat sequence of all previously sampled tokens.
                This is pre-flattened by forward() to ensure compatibility with potentials.
            unit_context (list, optional): Structured sequence of previously sampled units.
                Used by boundary predicates that need context. Defaults to [].
            draw (callable, optional): Sampling function passed to subunit_sampler

        Returns:
            (unit, weight, logp):
                - unit: List of subunits $[s_1, \\ldots, s_n]$ forming $x \\in \\mathcal{A}$
                - weight: Importance weight $w$ such that $(\\text{unit}, w)$ is properly
                    weighted w.r.t. $\\psi(x \\mid \\bm{x})$
                - logp: Sum of log-probabilities of sampling choices
        """
        if unit_context is None:
            unit_context = []

        subunit_buffer = []
        current_context = list(flat_token_context)

        # Accumulate weights
        cumulative_logw = 0.0
        cumulative_logp = 0.0

        # Sequential sampling until EOT
        for _ in range(self.max_subunits_per_unit):
            # Sample next subunit $(s_i, w_i) \\sim q_{\\text{sub}}(\\cdot \\mid \\bm{s}_{<i})$
            try:
                subunit, logw_i, logp_i = await self.subunit_sampler.sample(
                    current_context, draw
                )
            except (RuntimeError, OSError, TimeoutError):
                # Expected failures (network, timeout, system errors)
                # Return current buffer with -inf weight to discard this sample
                return subunit_buffer, float("-inf"), cumulative_logp

            # Accumulate weight and logp
            cumulative_logw += logw_i
            cumulative_logp += logp_i

            # Add to both buffer and context
            subunit_buffer.append(subunit)
            current_context.append(subunit)

            # Check for EOS
            if subunit is EOS:
                return subunit_buffer, cumulative_logw, cumulative_logp

            # Check boundary: is $\\bm{s} \\in \\mathcal{A}$ (complete unit)?
            if self.boundary_predicate(unit_context, subunit_buffer):
                # Let the predicate finalize the unit (e.g., remove delimiter tokens)
                unit = self.boundary_predicate.finalize_unit(subunit_buffer)
                return unit, cumulative_logw, cumulative_logp

        # Max subunits exceeded: we return -inf weight to reject incomplete/invalid unit
        return subunit_buffer, float("-inf"), cumulative_logp

    async def cleanup(self):
        """Clean up resources."""
        await self.subunit_sampler.cleanup()

start_weight() async

Return \(\overrightarrow{\psi}(\epsilon)\) (prefix weight of empty sequence).

Source code in genlm/control/sampler/unit.py
async def start_weight(self):
    """Return $\\overrightarrow{\\psi}(\\epsilon)$ (prefix weight of empty sequence)."""
    return await self.subunit_sampler.start_weight()

forward() async

Called by LLaMPPL Model.call() to sample one multi-token unit.

Called by SequenceModel.step() when it calls self.call(unit_sampler).

Source code in genlm/control/sampler/unit.py
async def forward(self):
    """Called by LLaMPPL Model.call() to sample one multi-token unit.

    Called by SequenceModel.step() when it calls self.call(unit_sampler).
    """
    parent = self.parent

    # Flatten parent.token_ctx before passing to sample
    # This ensures sample() always works with a flat list
    flat_context = flatten_units(parent.token_ctx)

    # Sample multi-token unit, passing both flat context and structured unit context
    unit, logw, logp = await self.sample(
        flat_context, unit_context=parent.token_ctx, draw=None
    )

    # Update parent's weight and logp
    parent.score(logw)
    parent.logp += logp

    # If the unit ends with EOS, return EOS directly so SequenceModel can detect completion
    # SequenceModel.step() checks `token_ctx[-1] is EOS` to finish generation
    if unit and unit[-1] is EOS:
        # Keep the unit content before EOS in token_ctx, then return EOS separately
        if len(unit) > 1:
            parent.token_ctx.append(unit[:-1])  # Add unit without EOS
        return EOS  # Return EOS directly for SequenceModel to detect

    return unit
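The EOS handling at the end of forward() can be sketched in isolation. Below is a minimal, hypothetical `split_eos` helper (not part of the library; `EOS` here is a local sentinel standing in for the library's EOS token) that mirrors the return logic:

```python
EOS = object()  # local sentinel standing in for the library's EOS token

def split_eos(unit):
    """Sketch of forward()'s return logic: if the sampled unit ends with
    EOS, keep the content before it and surface EOS separately so the
    sequence model can detect completion."""
    if unit and unit[-1] is EOS:
        content = unit[:-1] if len(unit) > 1 else None
        return content, EOS
    return unit, None
```

A unit of `[b"a", b"b", EOS]` thus yields content `[b"a", b"b"]` plus a bare `EOS` for the caller to append.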

sample(flat_token_context, unit_context=None, draw=None) async

Sample a multi-token unit by running sequential importance sampling (SIS) for the localized potential \(\varphi_{\bm{x}}\):

  1. Repeatedly sample \((s_i, w_i) \sim q_{\text{sub}}(\cdot \mid \bm{s}_{<i})\) until boundary
  2. Accumulate weights: \(w = \overrightarrow{\psi}_{\bm{x}}(\epsilon) \prod_i w_i\)
  3. Return \((\bm{s}, w)\) where \(\bm{s} \in \mathcal{B}^*\) forms unit \(x \in \mathcal{A}\)

Parameters:

Name Type Description Default
flat_token_context list

Flat sequence of all previously sampled tokens. This is pre-flattened by forward() to ensure compatibility with potentials.

required
unit_context list

Structured sequence of previously sampled units. Used by boundary predicates that need context. Defaults to [].

None
draw callable

Sampling function passed to subunit_sampler

None

Returns:

Type Description
(unit, weight, logp)
  • unit: List of subunits \([s_1, \ldots, s_n]\) forming \(x \in \mathcal{A}\)
  • weight: Importance weight \(w\) such that \((\text{unit}, w)\) is properly weighted w.r.t. \(\psi(x \mid \bm{x})\)
  • logp: Sum of log-probabilities of sampling choices
Source code in genlm/control/sampler/unit.py
async def sample(self, flat_token_context, unit_context=None, draw=None):
    """Sample a multi-token unit by running sequence sampling for $\\varphi_{\\bm{x}}$.
    SIS for the localized potential:

    1. Repeatedly sample $(s_i, w_i) \\sim q_{\\text{sub}}(\\cdot \\mid \\bm{s}_{<i})$ until boundary
    2. Accumulate weights: $w = \\overrightarrow{\\psi}_{\\bm{x}}(\\epsilon) \\prod_i w_i$
    3. Return $(\\bm{s}, w)$ where $\\bm{s} \\in \\mathcal{B}^*$ forms unit $x \\in \\mathcal{A}$

    Args:
        flat_token_context (list): Flat sequence of all previously sampled tokens.
            This is pre-flattened by forward() to ensure compatibility with potentials.
        unit_context (list, optional): Structured sequence of previously sampled units.
            Used by boundary predicates that need context. Defaults to [].
        draw (callable, optional): Sampling function passed to subunit_sampler

    Returns:
        (unit, weight, logp):
            - unit: List of subunits $[s_1, \\ldots, s_n]$ forming $x \\in \\mathcal{A}$
            - weight: Importance weight $w$ such that $(\\text{unit}, w)$ is properly
                weighted w.r.t. $\\psi(x \\mid \\bm{x})$
            - logp: Sum of log-probabilities of sampling choices
    """
    if unit_context is None:
        unit_context = []

    subunit_buffer = []
    current_context = list(flat_token_context)

    # Accumulate weights
    cumulative_logw = 0.0
    cumulative_logp = 0.0

    # Sequential sampling until EOS or a unit boundary
    for _ in range(self.max_subunits_per_unit):
        # Sample next subunit $(s_i, w_i) \\sim q_{\\text{sub}}(\\cdot \\mid \\bm{s}_{<i})$
        try:
            subunit, logw_i, logp_i = await self.subunit_sampler.sample(
                current_context, draw
            )
        except (RuntimeError, OSError, TimeoutError):
            # Expected failures (network, timeout, system errors)
            # Return current buffer with -inf weight to discard this sample
            return subunit_buffer, float("-inf"), cumulative_logp

        # Accumulate weight and logp
        cumulative_logw += logw_i
        cumulative_logp += logp_i

        # Add to both buffer and context
        subunit_buffer.append(subunit)
        current_context.append(subunit)

        # Check for EOS
        if subunit is EOS:
            return subunit_buffer, cumulative_logw, cumulative_logp

        # Check boundary: is $\\bm{s} \\in \\mathcal{A}$ (complete unit)?
        if self.boundary_predicate(unit_context, subunit_buffer):
            # Let the predicate finalize the unit (e.g., remove delimiter tokens)
            unit = self.boundary_predicate.finalize_unit(subunit_buffer)
            return unit, cumulative_logw, cumulative_logp

    # Max subunits exceeded: we return -inf weight to reject incomplete/invalid unit
    return subunit_buffer, float("-inf"), cumulative_logp
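The weight-accumulation loop above can be sketched as a self-contained toy. In this sketch, `mock_sampler` and the `EOS` sentinel are hypothetical stand-ins for the real subunit sampler and EOS token:

```python
import math

EOS = object()  # local sentinel standing in for the library's EOS token

def sample_unit(subunit_sampler, is_boundary, max_subunits=10):
    """Toy version of the SIS loop above: draw (subunit, logw, logp)
    triples, accumulate log-weights, and stop at EOS or a boundary."""
    buffer, logw, logp = [], 0.0, 0.0
    for _ in range(max_subunits):
        s, lw, lp = subunit_sampler(buffer)
        logw += lw
        logp += lp
        buffer.append(s)
        if s is EOS or is_boundary(buffer):
            return buffer, logw, logp
    return buffer, float("-inf"), logp  # reject incomplete units

# Deterministic mock sampler: emits b"a", b"b", then the boundary token b" "
def mock_sampler(buffer):
    seq = [b"a", b"b", b" "]
    return seq[len(buffer)], math.log(0.5), math.log(0.5)

unit, logw, logp = sample_unit(mock_sampler, lambda buf: buf[-1] == b" ")
```

Here the unit is `[b"a", b"b", b" "]` and the accumulated log-weight is the sum of the three per-subunit log-weights, as in step 2 of the algorithm.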

cleanup() async

Clean up resources.

Source code in genlm/control/sampler/unit.py
async def cleanup(self):
    """Clean up resources."""
    await self.subunit_sampler.cleanup()

BoundaryPredicate

Bases: ABC

Abstract base class for boundary predicates.

A boundary predicate determines when a sequence of subunits \(\bm{s} \in \mathcal{B}^*\) forms a complete unit \(x \in \mathcal{A}\).

__call__ method receives unit context and subunit buffer, allowing predicates to be stateless and context-aware.

finalize_unit method transforms the buffer into the final unit after boundary detection, allowing predicates to control what tokens are included (e.g., removing delimiter tokens).

Source code in genlm/control/sampler/unit.py
class BoundaryPredicate(ABC):
    """Abstract base class for boundary predicates.

    A boundary predicate determines when a sequence of subunits $\\bm{s} \\in \\mathcal{B}^*$
    forms a complete unit $x \\in \\mathcal{A}$.

    `__call__` method receives unit context and subunit buffer, allowing predicates
    to be stateless and context-aware.

    `finalize_unit` method transforms the buffer into the final unit after boundary
    detection, allowing predicates to control what tokens are included (e.g., removing
    delimiter tokens).
    """

    @abstractmethod
    def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
        """Check if subunit buffer forms a complete unit.

        Args:
            unit_context (list): Sequence of completed units $\\bm{x} \\in \\mathcal{A}^*$
            subunit_buffer (list): Current sequence of subunits $\\bm{s} \\in \\mathcal{B}^*$

        Returns:
            bool: True if $\\bm{s}$ forms a complete unit $x \\in \\mathcal{A}$
        """
        pass  # pragma: no cover

    def finalize_unit(self, subunit_buffer: list) -> list:
        """Transform buffer into final unit after boundary detected.

        Called after `__call__` returns True. Override to customize which tokens
        are included in the final unit (e.g., to remove delimiter tokens).

        Args:
            subunit_buffer (list): The buffer that triggered the boundary

        Returns:
            list: The final unit to return

        Note:
            Default implementation returns the entire buffer unchanged.
        """
        return subunit_buffer

__call__(unit_context, subunit_buffer) abstractmethod

Check if subunit buffer forms a complete unit.

Parameters:

Name Type Description Default
unit_context list

Sequence of completed units \(\bm{x} \in \mathcal{A}^*\)

required
subunit_buffer list

Current sequence of subunits \(\bm{s} \in \mathcal{B}^*\)

required

Returns:

Name Type Description
bool bool

True if \(\bm{s}\) forms a complete unit \(x \in \mathcal{A}\)

Source code in genlm/control/sampler/unit.py
@abstractmethod
def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
    """Check if subunit buffer forms a complete unit.

    Args:
        unit_context (list): Sequence of completed units $\\bm{x} \\in \\mathcal{A}^*$
        subunit_buffer (list): Current sequence of subunits $\\bm{s} \\in \\mathcal{B}^*$

    Returns:
        bool: True if $\\bm{s}$ forms a complete unit $x \\in \\mathcal{A}$
    """
    pass  # pragma: no cover

finalize_unit(subunit_buffer)

Transform buffer into final unit after boundary detected.

Called after __call__ returns True. Override to customize which tokens are included in the final unit (e.g., to remove delimiter tokens).

Parameters:

Name Type Description Default
subunit_buffer list

The buffer that triggered the boundary

required

Returns:

Name Type Description
list list

The final unit to return

Note

Default implementation returns the entire buffer unchanged.

Source code in genlm/control/sampler/unit.py
def finalize_unit(self, subunit_buffer: list) -> list:
    """Transform buffer into final unit after boundary detected.

    Called after `__call__` returns True. Override to customize which tokens
    are included in the final unit (e.g., to remove delimiter tokens).

    Args:
        subunit_buffer (list): The buffer that triggered the boundary

    Returns:
        list: The final unit to return

    Note:
        Default implementation returns the entire buffer unchanged.
    """
    return subunit_buffer
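Subclassing amounts to implementing `__call__` and optionally overriding `finalize_unit`. The sketch below uses a minimal copy of the interface and a hypothetical `StripDelimiterBoundary` (not part of the library) to show `finalize_unit` dropping the delimiter token:

```python
from abc import ABC, abstractmethod

class BoundaryPredicate(ABC):
    """Minimal copy of the interface documented above."""

    @abstractmethod
    def __call__(self, unit_context: list, subunit_buffer: list) -> bool: ...

    def finalize_unit(self, subunit_buffer: list) -> list:
        return subunit_buffer

class StripDelimiterBoundary(BoundaryPredicate):
    """Hypothetical predicate: a unit ends at a delimiter token,
    and finalize_unit removes that delimiter from the returned unit."""

    def __init__(self, delimiter: bytes):
        self.delimiter = delimiter

    def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
        return bool(subunit_buffer) and subunit_buffer[-1] == self.delimiter

    def finalize_unit(self, subunit_buffer: list) -> list:
        return subunit_buffer[:-1]  # drop the trailing delimiter

boundary = StripDelimiterBoundary(b" ")
```

With this predicate, a buffer of `[b"hi", b" "]` triggers the boundary but finalizes to `[b"hi"]`, unlike `TokenSetBoundary`, which keeps the boundary token.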

TokenSetBoundary

Bases: BoundaryPredicate

Stateless boundary predicate based on token membership.

A unit is complete when the last subunit is in a specified set of boundary tokens.

Parameters:

Name Type Description Default
boundary_tokens Iterable

Set or iterable of tokens that mark unit boundaries

required
Example

boundary = TokenSetBoundary({b" ", b"\n"}) boundary([], [b"hello", b" "]) # True (ends with whitespace)

Unit will be [b"hello", b" "] - boundary token included

Source code in genlm/control/sampler/unit.py
class TokenSetBoundary(BoundaryPredicate):
    """Stateless boundary predicate based on token membership.

    A unit is complete when the last subunit is in a specified set of boundary tokens.

    Args:
        boundary_tokens: Set or iterable of tokens that mark unit boundaries

    Example:
        >>> boundary = TokenSetBoundary({b" ", b"\\n"})
        >>> boundary([], [b"hello", b" "])  # True (ends with whitespace)
        >>> # Unit will be [b"hello", b" "] - boundary token included
    """

    def __init__(self, boundary_tokens: Iterable):
        self.boundary_tokens = set(boundary_tokens)

    def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
        """Check boundary (ignore unit_context for stateless predicate)."""
        return bool(subunit_buffer and subunit_buffer[-1] in self.boundary_tokens)

    def __repr__(self) -> str:
        return f"TokenSetBoundary({self.boundary_tokens!r})"

__call__(unit_context, subunit_buffer)

Check boundary (ignores unit_context for stateless predicate).

Source code in genlm/control/sampler/unit.py
def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
    """Check boundary (ignore unit_context for stateless predicate)."""
    return bool(subunit_buffer and subunit_buffer[-1] in self.boundary_tokens)

FixedLengthBoundary

Bases: BoundaryPredicate

Stateless boundary predicate based on fixed unit length. A unit is complete when it reaches a specified number of subunits.

Parameters:

Name Type Description Default
length int

Number of subunits per unit

required
Example

boundary = FixedLengthBoundary(10)
boundary([], [b"a"] * 9)   # False
boundary([], [b"a"] * 10)  # True

Source code in genlm/control/sampler/unit.py
class FixedLengthBoundary(BoundaryPredicate):
    """Stateless boundary predicate based on fixed unit length.
    A unit is complete when it reaches a specified number of subunits.

    Args:
        length (int): Number of subunits per unit

    Example:
        >>> boundary = FixedLengthBoundary(10)
        >>> boundary([], [b"a"] * 9)   # False
        >>> boundary([], [b"a"] * 10)  # True
    """

    def __init__(self, length: int):
        if length <= 0:
            raise ValueError(f"Length must be positive, got {length}")
        self.length = length

    def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
        """Check boundary (ignores unit_context for stateless predicate)."""
        return len(subunit_buffer) >= self.length

    def __repr__(self) -> str:
        return f"FixedLengthBoundary({self.length})"

__call__(unit_context, subunit_buffer)

Check boundary (ignores unit_context for stateless predicate).

Source code in genlm/control/sampler/unit.py
def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
    """Check boundary (ignores unit_context for stateless predicate)."""
    return len(subunit_buffer) >= self.length

CFGBoundary

Bases: BoundaryPredicate

Boundary predicate using Lark parser for context-free grammar-based boundaries.

This uses Lark's parser to determine when a sequence of subunits forms a syntactically complete unit according to a context-free grammar.

A unit can be marked as complete when:

  • The subunit buffer parses successfully
  • The parse tree's root matches one of the complete_rules

Parameters:

Name Type Description Default
grammar_text str

Lark grammar specification

required
start_rule str

Starting rule for parsing (default: "start")

'start'
complete_rules set or None

Set of rule names that constitute complete units. If None, any successful parse is complete. If provided, only parses with matching root are complete.

None
min_length int

Minimum buffer length before attempting to parse (default: 2)

2
parser_type str

Lark parser type: 'earley' (default, supports ambiguity) or 'lalr' (faster)

'earley'
ambiguity str

How to handle ambiguous grammars: 'explicit' (default) or 'resolve'

'explicit'
encoding str

Text encoding for token decoding (default: "utf-8")

'utf-8'
decode_errors str

How to handle decode errors (default: "ignore")

'ignore'
Example

# Simple arithmetic grammar
grammar = '''
    start: expr
    expr: term | expr "+" term
    term: NUMBER
    NUMBER: /[0-9]+/
'''
boundary = CFGBoundary(grammar, complete_rules={"start"})
boundary([], [b"1", b"+", b"2"])  # True (complete expression)
boundary([], [b"1", b"+"])        # False (incomplete)

Source code in genlm/control/sampler/unit.py
class CFGBoundary(BoundaryPredicate):
    """Boundary predicate using Lark parser for context-free grammar-based boundaries.

    This uses Lark's parser to determine when a sequence of subunits forms a
    syntactically complete unit according to a context-free grammar.

    A unit can be marked as complete when:
    - The subunit buffer parses successfully
    - The parse tree's root matches one of the complete_rules

    Args:
        grammar_text (str): Lark grammar specification
        start_rule (str): Starting rule for parsing (default: "start")
        complete_rules (set or None): Set of rule names that constitute complete units.
                                      If None, any successful parse is complete.
                                      If provided, only parses with matching root are complete.
        min_length (int): Minimum buffer length before attempting to parse (default: 2)
        parser_type (str): Lark parser type: 'earley' (default, supports ambiguity) or 'lalr' (faster)
        ambiguity (str): How to handle ambiguous grammars: 'explicit' (default) or 'resolve'
        encoding (str): Text encoding for token decoding (default: "utf-8")
        decode_errors (str): How to handle decode errors (default: "ignore")

    Example:
        >>> # Simple arithmetic grammar
        >>> grammar = '''
        ...     start: expr
        ...     expr: term | expr "+" term
        ...     term: NUMBER
        ...     NUMBER: /[0-9]+/
        ... '''
        >>> boundary = CFGBoundary(grammar, complete_rules={"start"})
        >>> boundary([], [b"1", b"+", b"2"])  # True (complete expression)
        >>> boundary([], [b"1", b"+"])        # False (incomplete)
    """

    def __init__(
        self,
        grammar_text,
        start_rule="start",
        complete_rules=None,
        min_length=2,
        parser_type="earley",
        ambiguity="explicit",
        encoding="utf-8",
        decode_errors="ignore",
    ):
        self.grammar_text = grammar_text
        self.start_rule = start_rule
        self.complete_rules = set(complete_rules) if complete_rules else None
        self.min_length = min_length
        self.encoding = encoding
        self.decode_errors = decode_errors
        try:
            if parser_type == "earley":
                self.parser = Lark(
                    grammar_text,
                    start=start_rule,
                    parser=parser_type,
                    ambiguity=ambiguity,
                )
            else:
                self.parser = Lark(grammar_text, start=start_rule, parser=parser_type)
        except Exception as e:
            raise ValueError(f"Failed to create Lark parser: {e}") from e

    def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
        """Check if buffer forms a complete syntactic unit.

        Args:
            unit_context: Previous completed units (ignored)
            subunit_buffer: Current sequence of subunits to check

        Returns:
            bool: True if buffer parses successfully and meets criteria
        """
        if not subunit_buffer or len(subunit_buffer) < self.min_length:
            return False

        try:
            text = self._tokens_to_text(subunit_buffer)

            if not text or len(text) < self.min_length:
                return False

            tree = self.parser.parse(text)
            if self.complete_rules is None:
                return True

            root_rule = tree.data
            return root_rule in self.complete_rules

        except LarkError:
            # Parse failed: not a complete unit
            return False

    def _tokens_to_text(self, tokens: list) -> str:
        """Convert token buffer to text string.

        Args:
            tokens: List of tokens (bytes objects or lists)

        Returns:
            str: Decoded text
        """
        # Join byte tokens, filtering out EOS
        token_bytes = b"".join(
            t for t in tokens if isinstance(t, bytes) and t is not EOS
        )
        return token_bytes.decode(self.encoding, errors=self.decode_errors)

    def get_parse_tree(self, text: str) -> Optional[Any]:
        """Get the parse tree for a given text.

        Args:
            text (str): String to parse

        Returns:
            Lark Tree object or None if parsing fails
        """
        try:
            return self.parser.parse(text)
        except LarkError:
            return None

    def __repr__(self) -> str:
        rules_str = (
            f", complete_rules={self.complete_rules}" if self.complete_rules else ""
        )
        return f"CFGBoundary(start={self.start_rule!r}{rules_str})"

__call__(unit_context, subunit_buffer)

Check if buffer forms a complete syntactic unit.

Parameters:

Name Type Description Default
unit_context list

Previous completed units (ignored)

required
subunit_buffer list

Current sequence of subunits to check

required

Returns:

Name Type Description
bool bool

True if buffer parses successfully and meets criteria

Source code in genlm/control/sampler/unit.py
def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
    """Check if buffer forms a complete syntactic unit.

    Args:
        unit_context: Previous completed units (ignored)
        subunit_buffer: Current sequence of subunits to check

    Returns:
        bool: True if buffer parses successfully and meets criteria
    """
    if not subunit_buffer or len(subunit_buffer) < self.min_length:
        return False

    try:
        text = self._tokens_to_text(subunit_buffer)

        if not text or len(text) < self.min_length:
            return False

        tree = self.parser.parse(text)
        if self.complete_rules is None:
            return True

        root_rule = tree.data
        return root_rule in self.complete_rules

    except LarkError:
        # Parse failed: not a complete unit
        return False

get_parse_tree(text)

Get the parse tree for a given text.

Parameters:

Name Type Description Default
text str

String to parse

required

Returns:

Type Description
Optional[Any]

Lark Tree object or None if parsing fails

Source code in genlm/control/sampler/unit.py
def get_parse_tree(self, text: str) -> Optional[Any]:
    """Get the parse tree for a given text.

    Args:
        text (str): String to parse

    Returns:
        Lark Tree object or None if parsing fails
    """
    try:
        return self.parser.parse(text)
    except LarkError:
        return None
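The parse-based boundary idea can be illustrated without Lark. The hypothetical `BalancedParenBoundary` below (not part of the library) substitutes a hand-rolled parenthesis-balance check for the grammar parse, but follows the same shape: decode the buffer, attempt a completeness check, and treat failure like a parse error:

```python
class BalancedParenBoundary:
    """Hypothetical stand-in for CFGBoundary: a unit is 'syntactically
    complete' when its parentheses are balanced, mirroring the
    parse-succeeds-then-boundary logic without a grammar library."""

    def __init__(self, min_length: int = 2):
        self.min_length = min_length

    def __call__(self, unit_context: list, subunit_buffer: list) -> bool:
        if len(subunit_buffer) < self.min_length:
            return False
        text = b"".join(t for t in subunit_buffer if isinstance(t, bytes))
        depth = 0
        for ch in text:
            if ch == ord("("):
                depth += 1
            elif ch == ord(")"):
                depth -= 1
                if depth < 0:  # malformed, analogous to a LarkError
                    return False
        return depth == 0  # fully balanced: complete unit

bp = BalancedParenBoundary()
```

As with CFGBoundary, `min_length` avoids running the (potentially expensive) completeness check on every single subunit.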

flatten_units(context)

Flatten nested unit context to a flat token list. When using MultiTokenUnitSampler, token_ctx becomes nested [[...], [...], ...]. This helper flattens it for use with coercion functions like b"".join.

Usage

potential.coerce(LLM, f=lambda ctx: b"".join(flatten_units(ctx)))

Parameters:

Name Type Description Default
context list

Either a flat list [token1, token2, ...] or nested [[token1, token2], [token3], ...]

required

Returns:

Type Description
list

Flattened list of tokens

Source code in genlm/control/sampler/unit.py
def flatten_units(context):
    """
    Flatten nested unit context to a flat token list. When using MultiTokenUnitSampler, token_ctx becomes nested [[...], [...], ...].
    This helper flattens it for use with coercion functions like b"".join.

    Usage:
        potential.coerce(LLM, f=lambda ctx: b"".join(flatten_units(ctx)))
    Args:
        context: Either a flat list [token1, token2, ...] or nested [[token1, token2], [token3], ...]
    Returns:
        list: Flattened list of tokens
    """
    flattened = []
    for item in context:
        if isinstance(item, list):
            flattened.extend(item)
        else:
            flattened.append(item)
    return flattened
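A quick usage check, with the function body copied from the source above:

```python
def flatten_units(context):
    """Flatten a possibly nested unit context into a flat token list."""
    flattened = []
    for item in context:
        if isinstance(item, list):
            flattened.extend(item)
        else:
            flattened.append(item)
    return flattened

# Mixed nested/flat context, as produced when some entries are multi-token units
nested = [[b"he", b"llo"], [b" "], b"world"]
flat = flatten_units(nested)
# b"".join(flat) == b"hello world"
```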

direct_token_sampler(potential)

Create a DirectTokenSampler that samples directly from a potential's vocabulary.

See DirectTokenSampler for more details.

Parameters:

Name Type Description Default
potential Potential

The potential function to sample from. Should have an efficient logw_next method.

required

Returns:

Type Description
DirectTokenSampler

A sampler that directly samples tokens from the potential's vocabulary.

Source code in genlm/control/sampler/__init__.py
def direct_token_sampler(potential):
    """Create a `DirectTokenSampler` that samples directly from a potential's vocabulary.

    See `DirectTokenSampler` for more details.

    Args:
        potential (Potential): The potential function to sample from. Should have an efficient logw_next method.

    Returns:
        (DirectTokenSampler): A sampler that directly samples tokens from the potential's vocabulary.
    """
    assert isinstance(potential, Potential)
    return DirectTokenSampler(potential)

eager_token_sampler(iter_potential, item_potential)

Create a SetTokenSampler that uses the EagerSetSampler to sample a set of tokens.

See EagerSetSampler for more details.

Parameters:

Name Type Description Default
iter_potential Potential

A potential function defined over a vocabulary of iterables.

required
item_potential Potential

A potential function defined over a vocabulary of items which are elements of the iterables.

required

Returns:

Type Description
SetTokenSampler

A sampler that wraps an EagerSetSampler.

Note

This is the fastest sampler in most cases.

Source code in genlm/control/sampler/__init__.py
def eager_token_sampler(iter_potential, item_potential):
    """Create a `SetTokenSampler` that uses the `EagerSetSampler` to sample a set of tokens.

    See `EagerSetSampler` for more details.

    Args:
        iter_potential (Potential): A potential function defined over a vocabulary of iterables.
        item_potential (Potential): A potential function defined over a vocabulary of items which are elements of the iterables.

    Returns:
        (SetTokenSampler): A sampler that wraps an `EagerSetSampler`.

    Note:
        This is the fastest sampler in most cases.
    """
    return SetTokenSampler(EagerSetSampler(iter_potential, item_potential))

topk_token_sampler(iter_potential, item_potential, K)

Create a SetTokenSampler that uses the TopKSetSampler to sample a set of tokens.

See TopKSetSampler for more details.

Parameters:

Name Type Description Default
iter_potential Potential

A potential function defined over a vocabulary of iterables.

required
item_potential Potential

A potential function defined over a vocabulary of items which are elements of the iterables.

required
K int | None

The K parameter for the TopKSetSampler.

required

Returns:

Type Description
SetTokenSampler

A sampler that wraps a TopKSetSampler.

Source code in genlm/control/sampler/__init__.py
def topk_token_sampler(iter_potential, item_potential, K):
    """Create a `SetTokenSampler` that uses the `TopKSetSampler` to sample a set of tokens.

    See `TopKSetSampler` for more details.

    Args:
        iter_potential (Potential): A potential function defined over a vocabulary of iterables.
        item_potential (Potential): A potential function defined over a vocabulary of items which are elements of the iterables.
        K (int|None): The `K` parameter for the `TopKSetSampler`.

    Returns:
        (SetTokenSampler): A sampler that wraps a `TopKSetSampler`.
    """
    return SetTokenSampler(TopKSetSampler(iter_potential, item_potential, K))