sampler

DirectTokenSampler

Bases: TokenSampler

Samples individual tokens directly from the log-normalized logw_next function of a potential.

Parameters:

- `potential` (`Potential`, required): The potential function to sample from.
Warning

Only use this sampler if the potential's logw_next method is efficient. This is the case for potentials like PromptedLLM, but for custom potentials with a large vocabulary size, the default implementation of logw_next generally will not be efficient, and thus this sampler will be slow.
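
A minimal usage sketch (assuming an already-constructed `PromptedLLM` instance named `llm`; the names here are illustrative placeholders):

sampler = DirectTokenSampler(llm)
token, logw, logp = await sampler.sample(context=[])  # properly weighted first-token sample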

Source code in genlm/control/sampler/token.py
class DirectTokenSampler(TokenSampler):
    """Samples individual tokens directly from the log-normalized `logw_next` function
    of a potential.

    Args:
        potential (Potential): The potential function to sample from

    Warning:
        Only use this sampler if the potential's `logw_next` method is efficient. This is the case
        for potentials like `PromptedLLM`, but for custom potentials with a large vocabulary size,
        the default implementation of `logw_next` generally will not be efficient, and thus this
        sampler will be slow.
    """

    def __init__(self, potential):
        super().__init__(target=potential)
        self.potential = potential

    async def sample(self, context, draw=None):
        """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method.

        Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the target potential's vocabulary,
        this method samples a token $x_n \\in \\textsf{target.vocab_eos}$ and weight $w$.

        The sampled token and weight are properly weighted with respect to
        $$
        \\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
        $$

        The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

        Returns:
            (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the sampled token.
        """
        logws = await self.potential.logw_next(context)
        logps = logws.normalize()
        if draw is None:
            # fast sampling from logps using gumbel-max trick
            token = fast_sample_lazyweights(logps)
        else:
            token = draw(logps.exp().materialize())
        return token, logws.sum(), logps[token]

    async def cleanup(self):
        pass  # pragma: no cover

sample(context, draw=None) async

Sample a token and weight that are properly weighted with respect to the target potential's logw_next method.

Given a context of tokens \(x_1, \ldots, x_{n-1}\) in the target potential's vocabulary, this method samples a token \(x_n \in \textsf{target.vocab_eos}\) and weight \(w\).

The sampled token and weight are properly weighted with respect to $$ \textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1}) $$

The returned weight corresponds to the log normalizing constant of \(\textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1})\).
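
Equivalently (restating the `logws.sum()` computation in the source below), the returned weight is $$ w = \log \sum_{x_n \in \textsf{target.vocab_eos}} \exp(\textsf{target.logw_next}(x_n \mid x_1, \ldots, x_{n-1})) $$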

Returns:

- `(token, weight, logp)`: A tuple containing the sampled token, weight, and log-probability of the sampled token.

Source code in genlm/control/sampler/token.py
async def sample(self, context, draw=None):
    """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method.

    Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the target potential's vocabulary,
    this method samples a token $x_n \\in \\textsf{target.vocab_eos}$ and weight $w$.

    The sampled token and weight are properly weighted with respect to
    $$
    \\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
    $$

    The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

    Returns:
        (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the sampled token.
    """
    logws = await self.potential.logw_next(context)
    logps = logws.normalize()
    if draw is None:
        # fast sampling from logps using gumbel-max trick
        token = fast_sample_lazyweights(logps)
    else:
        token = draw(logps.exp().materialize())
    return token, logws.sum(), logps[token]

SetTokenSampler

Bases: TokenSampler

Samples individual tokens by sampling a weighted set of tokens and then selecting one proportional to its weight.

This class wraps a SetSampler.

Parameters:

- `set_sampler` (`SetSampler`, required): The set sampler to sample from.
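
A minimal construction sketch (`iter_potential` and `item_potential` are illustrative placeholders for compatible potentials):

set_sampler = EagerSetSampler(iter_potential, item_potential)
token_sampler = SetTokenSampler(set_sampler)
token, logw, logp = await token_sampler.sample(context)
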
Source code in genlm/control/sampler/token.py
class SetTokenSampler(TokenSampler):
    """Samples individual tokens by sampling a weighted set of tokens and then selecting one
    proportional to its weight.

    This class wraps a `SetSampler`.

    Args:
        set_sampler (SetSampler): The set sampler to sample from
    """

    def __init__(self, set_sampler):
        assert isinstance(set_sampler, SetSampler)
        super().__init__(set_sampler.target)
        self.set_sampler = set_sampler

    async def sample(self, context, draw=None):
        """Sample a token and weight by sampling a weighted set of tokens from the `set_sampler`
        and then selecting one proportional to its weight.

        Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the vocabulary of the set sampler's target potential,
        this method samples a token $x_n \\in \\textsf{set_sampler.target.vocab_eos}$ and a weight.

        The sampled token and weight are properly weighted with respect to
        $$
        \\textsf{set_sampler.target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
        $$

        The returned weight corresponds to the sum of the weights of the sampled set.

        Args:
            context (list[int]): A sequence of tokens in the vocabulary of the set sampler's target potential.

        Returns:
            (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the random
                choices made in sampling that token.

        Note:
            For properly weighted sampling, the `set_sampler` must assign correct weights to each token. See
            `SetSampler` for more details.
        """
        logws, logp = await self.set_sampler.sample_set(context, draw=draw)
        logps = logws.normalize()
        if draw is None:
            token = fast_sample_lazyweights(logps)
        else:
            token = draw(logps.exp().materialize())
        return token, logws.sum(), logp + logps[token]

    async def cleanup(self):
        """Clean up the sampler.

        This method should be called when the sampler is no longer needed.
        """
        await self.set_sampler.cleanup()

sample(context, draw=None) async

Sample a token and weight by sampling a weighted set of tokens from the set_sampler and then selecting one proportional to its weight.

Given a context of tokens \(x_1, \ldots, x_{n-1}\) in the vocabulary of the set sampler's target potential, this method samples a token \(x_n \in \textsf{set_sampler.target.vocab_eos}\) and a weight.

The sampled token and weight are properly weighted with respect to $$ \textsf{set_sampler.target.logw_next}(x_n | x_1, \ldots, x_{n-1}) $$

The returned weight corresponds to the sum of the weights of the sampled set.

Parameters:

- `context` (`list[int]`, required): A sequence of tokens in the vocabulary of the set sampler's target potential.

Returns:

- `(token, weight, logp)`: A tuple containing the sampled token, weight, and log-probability of the random choices made in sampling that token.

Note

For properly weighted sampling, the set_sampler must assign correct weights to each token. See SetSampler for more details.

Source code in genlm/control/sampler/token.py
async def sample(self, context, draw=None):
    """Sample a token and weight by sampling a weighted set of tokens from the `set_sampler`
    and then selecting one proportional to its weight.

    Given a context of tokens $x_1, \\ldots, x_{n-1}$ in the vocabulary of the set sampler's target potential,
    this method samples a token $x_n \\in \\textsf{set_sampler.target.vocab_eos}$ and a weight.

    The sampled token and weight are properly weighted with respect to
    $$
    \\textsf{set_sampler.target.logw_next}(x_n | x_1, \\ldots, x_{n-1})
    $$

    The returned weight corresponds to the sum of the weights of the sampled set.

    Args:
        context (list[int]): A sequence of tokens in the vocabulary of the set sampler's target potential.

    Returns:
        (token, weight, logp): A tuple containing the sampled token, weight, and log-probability of the random
            choices made in sampling that token.

    Note:
        For properly weighted sampling, the `set_sampler` must assign correct weights to each token. See
        `SetSampler` for more details.
    """
    logws, logp = await self.set_sampler.sample_set(context, draw=draw)
    logps = logws.normalize()
    if draw is None:
        token = fast_sample_lazyweights(logps)
    else:
        token = draw(logps.exp().materialize())
    return token, logws.sum(), logp + logps[token]

cleanup() async

Clean up the sampler.

This method should be called when the sampler is no longer needed.

Source code in genlm/control/sampler/token.py
async def cleanup(self):
    """Clean up the sampler.

    This method should be called when the sampler is no longer needed.
    """
    await self.set_sampler.cleanup()

AWRS

Bases: TokenSampler

Samples individual tokens through an adaptive weighted rejection sampling algorithm.

This sampler is based on the algorithm described in Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling (https://arxiv.org/abs/2504.05410).

It draws properly weighted samples from the product of a non-boolean potential and a boolean condition.

Parameters:

- `potential` (`Potential`, required): The non-boolean potential.
- `condition` (`Potential`, required): The boolean condition. This potential must only output boolean values (0 or -inf in log-space).
- `seed` (`int` or `None`, default `None`): The seed for the random number generator.
- `prune_logws` (`bool`, default `True`): Whether to prune the logws to only include the tokens in the intersection of the potential and condition vocabularies.
- `proper_weights` (`bool`, default `True`): Whether to return properly weighted samples. If False, the sampler will only run one round of adaptive rejection sampling.
- `max_accepts` (`int`, default `2`): The maximum number of tokens to accept; higher values will decrease the variance of the weight estimate.
- `max_rejects` (`int` or `float('inf')`, default `float('inf')`): The maximum number of tokens to reject; lower values will run faster, but at the cost of returning a weight of zero for some samples where there are tokens that would be accepted if tested.
- `n_monte_carlo_samples` (`int`, default `None`): Deprecated; this parameter no longer has any effect (passing a value emits a `DeprecationWarning`).
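
A minimal usage sketch (`llm` stands in for a potential such as a `PromptedLLM`, and `constraint` for a boolean `Potential` over the same token type; both names are placeholders):

sampler = AWRS(llm, constraint)
token, logw, _ = await sampler.sample(context)  # third element is np.nan
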
Source code in genlm/control/sampler/token.py
class AWRS(TokenSampler):
    """Samples individual tokens through an adaptive weighted rejection sampling algorithm.

    This sampler is based on the algorithm described in [Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling](https://arxiv.org/abs/2504.05410)

    It draws properly weighted samples from the product of a non-boolean potential and a boolean condition.

    Args:
        potential (Potential): The non-boolean potential.
        condition (Potential): The boolean condition. This potential must only output boolean values (0 or -inf in log-space).
        seed (int or None): The seed for the random number generator.
        prune_logws (bool): Whether to prune the logws to only include the tokens in the intersection of the potential and condition vocabularies
        proper_weights (bool): Whether to return properly weighted samples.
            If False, the sampler will only run one round of adaptive rejection sampling.
        max_accepts (int): The maximum number of tokens to accept - higher values will decrease the variance of the weight estimate.
        max_rejects (int or float('inf')): The maximum number of tokens to reject - lower values will run faster, but at the cost of returning a weight of zero for some samples where there are tokens that would be accepted if tested.
        n_monte_carlo_samples (int): Deprecated; this argument no longer has any effect (passing a value emits a `DeprecationWarning`).
    """

    def __init__(
        self,
        potential,
        condition,
        seed=None,
        prune_logws=True,
        proper_weights=True,
        max_accepts=2,
        max_rejects=float("inf"),
        n_monte_carlo_samples=None,
    ):
        super().__init__(target=potential * condition)
        self.potential = potential
        self.condition = condition

        self.prune_logws = prune_logws
        self.proper_weights = proper_weights

        if max_accepts < 2 and proper_weights:
            raise ValueError("`max_accepts` must be at least 2")

        if max_rejects < 2 and proper_weights:
            raise ValueError("`max_rejects` must be at least 2")

        if n_monte_carlo_samples is not None:
            warnings.warn(
                "n_monte_carlo_samples no longer does anything.",
                DeprecationWarning,
            )

        self.max_accepts = max_accepts
        self.max_rejects = max_rejects or float("inf")

        self.valid_idxs = np.array(
            [self.potential.lookup[t] for t in self.target.vocab_eos]
        )

        self.vocab_eos_set = set(self.target.vocab_eos)
        self.V = len(self.potential.vocab_eos)
        self.rng = np.random.default_rng(seed=seed)

    def _prune_logws(self, logws):
        # Prune the logws to only include the tokens in the
        # target vocabulary. (This zeros-out tokens which we know a priori
        # will be rejected.) Note: We need an additional correction term
        # to account for the fact that we're throwing away some probability mass.
        # This should be handled in `sample`.
        pruned = self.potential.alloc_logws()
        pruned[self.valid_idxs] = logws.weights[self.valid_idxs]
        logws.weights = pruned
        return logws

    async def _accept(self, context, token, verbosity=0):
        if self.prune_logws or token in self.vocab_eos_set:
            if token is self.target.eos:
                logscore = await self.condition.complete(context)
            else:
                logscore = await self.condition.prefix(context + [token])
            assert logscore in {-np.inf, 0}, "`condition` must be Boolean"
        else:
            logscore = -np.inf

        do_accept = logscore == 0

        if verbosity > 0:
            if do_accept:
                print(colors.green % f". {repr(token)}")
            else:
                print(colors.red % ".", end="")

        return do_accept

    async def sample(self, context, verbosity=0):
        """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method via adaptive weighted rejection sampling.

        The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

        Returns:
            (token, weight, np.nan): A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.
        """
        logws = await self.potential.logw_next(context)
        if self.prune_logws:
            logws = self._prune_logws(logws)

        logZ = logsumexp(logws.weights)
        logps = logws.weights - logZ
        toks = logws.decode

        # We cache successful calls, as algorithms may want to see the
        # same successful token more than once (currently just geometric_awrs)
        cache = {}

        async def accept(tok):
            try:
                return cache[tok]
            except KeyError:
                pass
            result = await self._accept(context, tok, verbosity)
            if result:
                cache[tok] = result
            return result

        if not self.proper_weights:
            return await improper_sample(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
            )
        # We pick which algorithm to use based on parameters and the
        # shape of the distribution, as this lets us pick the most
        # effective option.
        elif (
            # If max_accepts is large then recursive_awrs (which
            # does not currently support this parameter) isn't very
            # useful, because the recursive step means that you never
            # revisit the same value, so will often throw away most
            # of the accepted mass if you were to continue. Also
            # this parameter is only really relevant if you want to
            # lower the variance, and geometric_awrs is lower variance.
            self.max_accepts > 2
            or
            # If the distribution is strongly peaked around a single value
            # then geometric_awrs will be more efficient. See below
            # for specific derivation.
            logps.max() >= GEOMETRIC_THRESHOLD
        ):
            tok, w, _ = await geometric_awrs(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
                max_accepts=self.max_accepts,
            )
            return tok, w + logZ, np.nan
        else:
            tok, w, _ = await recursive_awrs(
                logps=logps,
                toks=toks,
                accept=accept,
                rng=self.rng,
                max_rejects=self.max_rejects,
            )
            return tok, w + logZ, np.nan

sample(context, verbosity=0) async

Sample a token and weight that are properly weighted with respect to the target potential's logw_next method via adaptive weighted rejection sampling.

The returned weight corresponds to the log normalizing constant of \(\textsf{target.logw_next}(x_n | x_1, \ldots, x_{n-1})\).

Returns:

- `(token, weight, np.nan)`: A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.

Source code in genlm/control/sampler/token.py
async def sample(self, context, verbosity=0):
    """Sample a token and weight that are properly weighted with respect to the target potential's `logw_next` method via adaptive weighted rejection sampling.

    The returned weight corresponds to the log normalizing constant of $\\textsf{target.logw_next}(x_n | x_1, \\ldots, x_{n-1})$.

    Returns:
        (token, weight, np.nan): A tuple containing the sampled token, weight, and a dummy value for the log-probability of the sampled token.
    """
    logws = await self.potential.logw_next(context)
    if self.prune_logws:
        logws = self._prune_logws(logws)

    logZ = logsumexp(logws.weights)
    logps = logws.weights - logZ
    toks = logws.decode

    # We cache successful calls, as algorithms may want to see the
    # same successful token more than once (currently just geometric_awrs)
    cache = {}

    async def accept(tok):
        try:
            return cache[tok]
        except KeyError:
            pass
        result = await self._accept(context, tok, verbosity)
        if result:
            cache[tok] = result
        return result

    if not self.proper_weights:
        return await improper_sample(
            logps=logps,
            toks=toks,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
        )
    # We pick which algorithm to use based on parameters and the
    # shape of the distribution, as this lets us pick the most
    # effective option.
    elif (
        # If max_accepts is large then recursive_awrs (which
        # does not currently support this parameter) isn't very
        # useful, because the recursive step means that you never
        # revisit the same value, so will often throw away most
        # of the accepted mass if you were to continue. Also
        # this parameter is only really relevant if you want to
        # lower the variance, and geometric_awrs is lower variance.
        self.max_accepts > 2
        or
        # If the distribution is strongly peaked around a single value
        # then geometric_awrs will be more efficient. See below
        # for specific derivation.
        logps.max() >= GEOMETRIC_THRESHOLD
    ):
        tok, w, _ = await geometric_awrs(
            logps=logps,
            toks=toks,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
            max_accepts=self.max_accepts,
        )
        return tok, w + logZ, np.nan
    else:
        tok, w, _ = await recursive_awrs(
            logps=logps,
            toks=toks,
            accept=accept,
            rng=self.rng,
            max_rejects=self.max_rejects,
        )
        return tok, w + logZ, np.nan

EagerSetSampler

Bases: TrieSetSampler

A trie-based set sampler that implements an eager sampling strategy for generating a set of tokens.

An EagerSetSampler samples tokens by incrementally sampling items from the item-wise product of the iter_potential and item_potential. The sampled set is the set of sequences of items that correspond to valid tokens in iter_potential's vocabulary.
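
A minimal sketch of using the sampler directly (in practice, the `eager_token_sampler` factory below constructs and wraps one for you; the potential names are placeholders):

set_sampler = EagerSetSampler(iter_potential, item_potential)
logws, logp = await set_sampler.sample_set(context)  # weighted token set and its sampling log-probability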

Source code in genlm/control/sampler/set.py
class EagerSetSampler(TrieSetSampler):
    """
    A trie-based set sampler that implements an eager sampling strategy
    for generating a set of tokens.

    An `EagerSetSampler` samples tokens by incrementally sampling items from the item-wise product of the `iter_potential` and `item_potential`.
    The sampled set is the set of sequences of items that correspond to valid tokens in `iter_potential`'s vocabulary.
    """

    async def sample_set(self, context, draw=None):
        """
        Sample a set of tokens given a context.

        Args:
            context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

        Returns:
            (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
        """
        if draw is None:
            draw = sample_dict
        iter_logws = await self.iter_potential.logw_next(context)
        item_ws = await self.trie_executor.weight_sum(iter_logws.exp().weights)

        logws = self.target.alloc_logws()
        curr = self.trie.root
        coerced_ctx = self.f(context)
        subtokens = []
        logp, logw = 0, 0

        while True:
            children = self.trie.children[curr]
            item_w_curr = item_ws[curr]
            item_ws1 = Float.chart(
                {a: item_ws[c] / item_w_curr for a, c in children.items()}
            )

            if None in item_ws1:
                leaf = children[None]
                token = self.trie.leaf2word[leaf]
                token_id = self.leaf_to_token_id[leaf]
                logws[token_id] = iter_logws[token] + logw - logp

            item_logws2 = await self.item_potential.logw_next(coerced_ctx + subtokens)
            item_ws2 = item_logws2.exp().materialize()
            w_next = (item_ws1 * item_ws2).trim()

            if not w_next:
                break

            ps = w_next.normalize()
            b = draw(ps)
            logp += np.log(ps[b])
            logw += item_logws2[b]

            if b == self.target.eos:
                assert not subtokens, "subtokens should be empty at EOS."
                logws[-1] = iter_logws[self.target.eos] + logw - logp
                break

            subtokens.append(b)
            curr = children[b]

        return self.target.make_lazy_weights(logws), logp

sample_set(context, draw=None) async

Sample a set of tokens given a context.

Parameters:

- `context` (`list`, required): A sequence of tokens in the `iter_potential`'s vocabulary.

Returns:

- `(LazyWeights, float)`: A weighted set of tokens and the log-probability of the sampled set.

Source code in genlm/control/sampler/set.py
async def sample_set(self, context, draw=None):
    """
    Sample a set of tokens given a context.

    Args:
        context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

    Returns:
        (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
    """
    if draw is None:
        draw = sample_dict
    iter_logws = await self.iter_potential.logw_next(context)
    item_ws = await self.trie_executor.weight_sum(iter_logws.exp().weights)

    logws = self.target.alloc_logws()
    curr = self.trie.root
    coerced_ctx = self.f(context)
    subtokens = []
    logp, logw = 0, 0

    while True:
        children = self.trie.children[curr]
        item_w_curr = item_ws[curr]
        item_ws1 = Float.chart(
            {a: item_ws[c] / item_w_curr for a, c in children.items()}
        )

        if None in item_ws1:
            leaf = children[None]
            token = self.trie.leaf2word[leaf]
            token_id = self.leaf_to_token_id[leaf]
            logws[token_id] = iter_logws[token] + logw - logp

        item_logws2 = await self.item_potential.logw_next(coerced_ctx + subtokens)
        item_ws2 = item_logws2.exp().materialize()
        w_next = (item_ws1 * item_ws2).trim()

        if not w_next:
            break

        ps = w_next.normalize()
        b = draw(ps)
        logp += np.log(ps[b])
        logw += item_logws2[b]

        if b == self.target.eos:
            assert not subtokens, "subtokens should be empty at EOS."
            logws[-1] = iter_logws[self.target.eos] + logw - logp
            break

        subtokens.append(b)
        curr = children[b]

    return self.target.make_lazy_weights(logws), logp

TopKSetSampler

Bases: TrieSetSampler

A trie-based set sampler that lazily enumerates the top K tokens by weight in the target, and samples an additional "wildcard" token to ensure absolute continuity.

Warning

This sampler is not guaranteed to be correct if the item_potential's prefix weights do not monotonically decrease with the length of the context. That is, correctness requires \(\textsf{item_potential.prefix}(xy) \leq \textsf{item_potential.prefix}(x)\) for all sequences of items \(x, y\).
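
A minimal sketch (the choice K=10 and the potential names are illustrative placeholders):

set_sampler = TopKSetSampler(iter_potential, item_potential, K=10)
logws, logp = await set_sampler.sample_set(context)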

Source code in genlm/control/sampler/set.py
class TopKSetSampler(TrieSetSampler):
    """
    A trie-based set sampler that lazily enumerates the top K tokens by weight in the target,
    and samples an additional "wildcard" token to ensure absolute continuity.

    Warning:
        This sampler is not guaranteed to be correct if the `item_potential`'s
        prefix weights do not monotonically decrease with the length of the context.
        That is, correctness requires $\\textsf{item_potential.prefix}(xy) \\leq \\textsf{item_potential.prefix}(x)$ for all sequences of items $x, y$.
    """

    def __init__(self, iter_potential, item_potential, K):
        """
        Initialize the TopKSetSampler.

        Args:
            iter_potential (Potential): The potential defined over a vocabulary of iterables.
            item_potential (Potential): The potential defined over a vocabulary of items.
            K (int|None): The number of top tokens to enumerate. If None, all tokens are enumerated.
        """
        if K is not None and K <= 0:
            raise ValueError("K must be greater than 0 or None")
        super().__init__(iter_potential, item_potential)
        self.K = K

    async def sample_set(self, context, draw=None):
        """
        Sample a set of tokens given a context.

        Args:
            context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

        Returns:
            (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
        """
        if draw is None:
            draw = sample_dict
        iter_logws = await self.iter_potential.logw_next(context)
        max_logws = await self.trie_executor.weight_max(iter_logws.weights)

        k = 0
        logws = self.target.alloc_logws()
        sampled = self.target.alloc_logws(default=False)

        async for token_id, logw in self._lazy_enum(context, max_logws):
            logws[token_id] = logw
            sampled[token_id] = True
            k += 1
            if self.K is not None and k >= self.K:
                break

        logp_wc = 0
        if self.K is not None and k == self.K:
            # Get the distribution over wildcard tokens
            iter_ws = iter_logws.exp()
            W_wc = Float.chart(
                {
                    token_id: iter_ws[token]
                    for token_id, token in enumerate(self.target.vocab_eos)
                    if not sampled[token_id]
                }
            )

            # if W_wc is non-empty, sample a wildcard token to ensure absolute continuity
            if W_wc:
                P_wc = W_wc.normalize()
                wc_id = draw(P_wc)
                logp_wc = np.log(P_wc[wc_id])
                wc = self.target.vocab_eos[wc_id]
                item_ctx = self.f(context)
                prefix_w = await self.item_potential.prefix(item_ctx)
                if wc == self.target.eos:
                    w_guide_wc = await self.item_potential.complete(item_ctx) - prefix_w
                else:
                    w_guide_wc = (
                        await self.item_potential.prefix(self.f(context + [wc]))
                        - prefix_w
                    )
                logws[wc_id] = np.log(W_wc[wc_id]) + w_guide_wc - logp_wc

        return self.target.make_lazy_weights(logws), logp_wc

    async def _lazy_enum(self, context, max_logws):
        agenda = LocatorMaxHeap()

        W = Float.chart()

        # initial conditions
        (token, node) = ((), self.trie.root)
        agenda[token, node, False] = max_logws[node]
        W[node] = 0

        children = self.trie.children
        coerced_ctx = self.f(context)

        curr_priority = float("inf")
        prev_best = float("inf")
        while agenda:
            (token, node, done), score = agenda.popitem()

            assert score <= curr_priority, (
                "Monotonicity assumption violated. "
                "`item_potential` prefix weight must be monotonically decreasing."
            )
            curr_priority = score

            # terminal state
            if done:
                value = W[node] + max_logws[node]
                assert prev_best >= value
                prev_best = value
                yield (self.leaf_to_token_id[node], value)
                continue

            logws = None
            for x, y in children[node].items():
                if x is None:
                    W_y = W[node]
                    W[y] = W_y
                    agenda[token, y, True] = W_y + max_logws[y]
                else:
                    if logws is None:
                        logws = await self.item_potential.logw_next(
                            coerced_ctx + list(token)
                        )
                    W_y = W[node] + logws[x]
                    if W_y == float("-inf"):
                        continue
                    W[y] = W_y
                    agenda[(*token, x), y, False] = W_y + max_logws[y]

__init__(iter_potential, item_potential, K)

Initialize the TopKSetSampler.

Parameters:

- `iter_potential` (`Potential`, required): The potential defined over a vocabulary of iterables.
- `item_potential` (`Potential`, required): The potential defined over a vocabulary of items.
- `K` (`int | None`, required): The number of top tokens to enumerate. If None, all tokens are enumerated.
Source code in genlm/control/sampler/set.py
def __init__(self, iter_potential, item_potential, K):
    """
    Initialize the TopKSetSampler.

    Args:
        iter_potential (Potential): The potential defined over a vocabulary of iterables.
        item_potential (Potential): The potential defined over a vocabulary of items.
        K (int|None): The number of top tokens to enumerate. If None, all tokens are enumerated.
    """
    if K is not None and K <= 0:
        raise ValueError("K must be greater than 0 or None")
    super().__init__(iter_potential, item_potential)
    self.K = K

sample_set(context, draw=None) async

Sample a set of tokens given a context.

Parameters:

- `context` (`list`, required): A sequence of tokens in the `iter_potential`'s vocabulary.

Returns:

- `(LazyWeights, float)`: A weighted set of tokens and the log-probability of the sampled set.

Source code in genlm/control/sampler/set.py
async def sample_set(self, context, draw=None):
    """
    Sample a set of tokens given a context.

    Args:
        context (list): A sequence of tokens in the `iter_potential`'s vocabulary.

    Returns:
        (LazyWeights, float): A weighted set of tokens and the log-probability of the sampled set.
    """
    if draw is None:
        draw = sample_dict
    iter_logws = await self.iter_potential.logw_next(context)
    max_logws = await self.trie_executor.weight_max(iter_logws.weights)

    k = 0
    logws = self.target.alloc_logws()
    sampled = self.target.alloc_logws(default=False)

    async for token_id, logw in self._lazy_enum(context, max_logws):
        logws[token_id] = logw
        sampled[token_id] = True
        k += 1
        if self.K is not None and k >= self.K:
            break

    logp_wc = 0
    if self.K is not None and k == self.K:
        # Get the distribution over wildcard tokens
        iter_ws = iter_logws.exp()
        W_wc = Float.chart(
            {
                token_id: iter_ws[token]
                for token_id, token in enumerate(self.target.vocab_eos)
                if not sampled[token_id]
            }
        )

        # if W_wc is non-empty, sample a wildcard token to ensure absolute continuity
        if W_wc:
            P_wc = W_wc.normalize()
            wc_id = draw(P_wc)
            logp_wc = np.log(P_wc[wc_id])
            wc = self.target.vocab_eos[wc_id]
            item_ctx = self.f(context)
            prefix_w = await self.item_potential.prefix(item_ctx)
            if wc == self.target.eos:
                w_guide_wc = await self.item_potential.complete(item_ctx) - prefix_w
            else:
                w_guide_wc = (
                    await self.item_potential.prefix(self.f(context + [wc]))
                    - prefix_w
                )
            logws[wc_id] = np.log(W_wc[wc_id]) + w_guide_wc - logp_wc

    return self.target.make_lazy_weights(logws), logp_wc

SMC

This class implements sequential Monte Carlo (SMC) inference for controlled text generation. The generation process works as follows:

  1. Token Sampling: At each step, the unit_sampler is used to extend each particle (candidate sequence) by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs an importance weight with each extension to correct for the myopic nature of token-by-token sampling.

  2. Critic Evaluation: If a critic is provided, it scores the updated sequences (via its score method), reweighting the particles based on how well they satisfy the constraints encoded by the critic.

  3. Resampling: When the effective sample size (ESS) falls below the threshold, particles are resampled according to their weights. This helps focus computation on more promising sequences.

  4. Termination: The process continues until either:

    • All sequences reach an end-of-sequence (EOS) token

    • The maximum token length is reached

If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's target potential and the critic potential (unit_sampler.target * critic). If a critic is not provided, the resulting sequences are weighted with respect to the unit sampler's target potential.

Parameters:

- `unit_sampler` (`TokenSampler`, required): The sampler that generates tokens.
- `critic` (`Potential`, default `None`): A potential function that guides the generation process by scoring candidate sequences. Must have the same token type as the `unit_sampler`.

Raises:

- `ValueError`: If `unit_sampler` is not a `TokenSampler`, if `critic` is not a `Potential`, or if the token types of `unit_sampler` and `critic` don't match.
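
A minimal end-to-end sketch (`llm` and `constraint` are illustrative placeholders for a potential and a compatible critic):

token_sampler = direct_token_sampler(llm)
smc = SMC(token_sampler, critic=constraint)
sequences = await smc(n_particles=10, ess_threshold=0.5, max_tokens=30)
await smc.cleanup()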

Source code in genlm/control/sampler/sequence.py
class SMC:
    """This class implements sequential Monte Carlo (SMC) inference for controlled text generation.
    The generation process works as follows:

    1. Token Sampling: At each step, the `unit_sampler` is used to extend each particle (candidate sequence)
       by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs
       an importance weight with each extension to correct for the myopic nature of token-by-token sampling.

    2. Critic Evaluation: If a `critic` is provided, it scores the updated sequences (via its `score` method),
       reweighting the particles based on how well they satisfy the constraints encoded by the critic.

    3. Resampling: When the effective sample size (ESS) falls below the threshold,
       particles are resampled according to their weights. This helps focus computation
       on more promising sequences.

    4. Termination: The process continues until either:\n
        - All sequences reach an end-of-sequence (EOS) token\n
        - The maximum token length is reached

    If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's
    target potential and the critic potential (`unit_sampler.target * critic`). If a critic is not provided,
    the resulting sequences are weighted with respect to the unit sampler's target potential.

    Args:
        unit_sampler (TokenSampler): The sampler that generates tokens.
        critic (Potential, optional): A potential function that guides the generation process
            by scoring candidate sequences. Must have the same token type as the unit_sampler.

    Raises:
        ValueError: If unit_sampler is not a TokenSampler, if critic is not a Potential,
            or if the token types of unit_sampler and critic don't match.
    """

    def __init__(self, unit_sampler, critic=None):
        if not isinstance(unit_sampler, TokenSampler):
            raise ValueError("`unit_sampler` must be a TokenSampler")

        if critic:
            if not isinstance(critic, Potential):
                raise ValueError("`critic` must be a Potential")
            if not unit_sampler.token_type == critic.token_type:
                raise ValueError(
                    "`critic` must have the same token type as the `unit_sampler`. "
                    f"Got {unit_sampler.token_type} and {critic.token_type}."
                    + (
                        "\nMaybe you forgot to coerce the critic to the token type of the unit sampler? See `Coerce`."
                        if unit_sampler.token_type.is_iterable_of(critic.token_type)
                        else ""
                    )
                )

        self.unit_sampler = unit_sampler
        self.critic = critic

    async def __call__(
        self,
        n_particles,
        ess_threshold,
        max_tokens,
        verbosity=0,
        json_path=None,
        **kwargs,
    ):
        """Generate sequences using sequential Monte Carlo inference.

        Args:
            n_particles (int): Number of particles (candidate sequences) to maintain during
                generation. Higher values provide better exploration but require more
                computation.
            ess_threshold (float): Effective sample size threshold for resampling,
                expressed as a fraction of the number of particles. When ESS falls below
                this value, particles are resampled according to their weights. Should be between 0 and 1.
                Higher values lead to more frequent resampling. Note that when ess_threshold = 0,
                the critic is only applied at the end of the generation (if it is provided).
            max_tokens (int): Maximum number of tokens to generate per sequence. Generation
                may terminate earlier if all sequences reach an EOS token.
            verbosity (int, optional): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the
                particles at each step. Default is 0.
            json_path (str, optional): JSON file path for saving a record of the inference run.
                This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
            **kwargs (dict): Additional keyword arguments to pass to the SMC algorithm.
                See the `llamppl.inference.smc_standard` documentation for more details.

        Returns:
            (Sequences): A container holding the generated sequences, their importance weights, and
                other metadata from the generation process.
        """
        model = SequenceModel(
            unit_sampler=self.unit_sampler,
            critic=self.critic,
            max_tokens=max_tokens,
            verbosity=verbosity,
            twist_with_critic=ess_threshold > 0,
        )

        particles = await smc_standard(
            model=model,
            n_particles=n_particles,
            ess_threshold=ess_threshold,
            json_file=json_path,
            **kwargs,
        )

        return Sequences(*_unpack_particles(particles))

    async def cleanup(self):
        """Clean up resources used by the inference engine.

        This method should be called when the SMC instance is no longer needed.

        Example:
            ```python
            sampler = SMC(unit_sampler, critic)
            try:
                sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
            finally:
                await sampler.cleanup()
            ```
        """
        await self.unit_sampler.cleanup()
        if self.critic:
            await self.critic.cleanup()

__call__(n_particles, ess_threshold, max_tokens, verbosity=0, json_path=None, **kwargs) async

Generate sequences using sequential Monte Carlo inference.

Parameters:

- `n_particles` (`int`, required): Number of particles (candidate sequences) to maintain during generation. Higher values provide better exploration but require more computation.
- `ess_threshold` (`float`, required): Effective sample size threshold for resampling, expressed as a fraction of the number of particles. When ESS falls below this value, particles are resampled according to their weights. Should be between 0 and 1. Higher values lead to more frequent resampling. Note that when `ess_threshold` = 0, the critic is only applied at the end of the generation (if it is provided).
- `max_tokens` (`int`, required): Maximum number of tokens to generate per sequence. Generation may terminate earlier if all sequences reach an EOS token.
- `verbosity` (`int`, default `0`): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the particles at each step.
- `json_path` (`str`, default `None`): JSON file path for saving a record of the inference run. This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
- `**kwargs` (`dict`): Additional keyword arguments to pass to the SMC algorithm. See the `llamppl.inference.smc_standard` documentation for more details.

Returns:

- `Sequences`: A container holding the generated sequences, their importance weights, and other metadata from the generation process.

Source code in genlm/control/sampler/sequence.py
async def __call__(
    self,
    n_particles,
    ess_threshold,
    max_tokens,
    verbosity=0,
    json_path=None,
    **kwargs,
):
    """Generate sequences using sequential Monte Carlo inference.

    Args:
        n_particles (int): Number of particles (candidate sequences) to maintain during
            generation. Higher values provide better exploration but require more
            computation.
        ess_threshold (float): Effective sample size threshold for resampling,
            expressed as a fraction of the number of particles. When ESS falls below
            this value, particles are resampled according to their weights. Should be between 0 and 1.
            Higher values lead to more frequent resampling. Note that when ess_threshold = 0,
            the critic is only applied at the end of the generation (if it is provided).
        max_tokens (int): Maximum number of tokens to generate per sequence. Generation
            may terminate earlier if all sequences reach an EOS token.
        verbosity (int, optional): Verbosity level for the SMC algorithm. 0 is silent, 1 prints the
            particles at each step. Default is 0.
        json_path (str, optional): JSON file path for saving a record of the inference run.
            This can be used in conjunction with the `InferenceVisualizer` to visualize the inference run.
        **kwargs (dict): Additional keyword arguments to pass to the SMC algorithm.
            See the `llamppl.inference.smc_standard` documentation for more details.

    Returns:
        (Sequences): A container holding the generated sequences, their importance weights, and
            other metadata from the generation process.
    """
    model = SequenceModel(
        unit_sampler=self.unit_sampler,
        critic=self.critic,
        max_tokens=max_tokens,
        verbosity=verbosity,
        twist_with_critic=ess_threshold > 0,
    )

    particles = await smc_standard(
        model=model,
        n_particles=n_particles,
        ess_threshold=ess_threshold,
        json_file=json_path,
        **kwargs,
    )

    return Sequences(*_unpack_particles(particles))

cleanup() async

Clean up resources used by the inference engine.

This method should be called when the SMC instance is no longer needed.

Example
sampler = SMC(unit_sampler, critic)
try:
    sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
finally:
    await sampler.cleanup()
Source code in genlm/control/sampler/sequence.py
async def cleanup(self):
    """Clean up resources used by the inference engine.

    This method should be called when the SMC instance is no longer needed.

    Example:
        ```python
        sampler = SMC(unit_sampler, critic)
        try:
            sequences = await sampler(n_particles=10, ess_threshold=0.5, max_tokens=20)
        finally:
            await sampler.cleanup()
        ```
    """
    await self.unit_sampler.cleanup()
    if self.critic:
        await self.critic.cleanup()

direct_token_sampler(potential)

Create a DirectTokenSampler that samples directly from a potential's vocabulary.

See DirectTokenSampler for more details.

Parameters:

- `potential` (`Potential`, required): The potential function to sample from. Should have an efficient `logw_next` method.

Returns:

- `DirectTokenSampler`: A sampler that directly samples tokens from the potential's vocabulary.

Source code in genlm/control/sampler/__init__.py
def direct_token_sampler(potential):
    """Create a `DirectTokenSampler` that samples directly from a potential's vocabulary.

    See `DirectTokenSampler` for more details.

    Args:
        potential (Potential): The potential function to sample from. Should have an efficient logw_next method.

    Returns:
        (DirectTokenSampler): A sampler that directly samples tokens from the potential's vocabulary.
    """
    assert isinstance(potential, Potential)
    return DirectTokenSampler(potential)

eager_token_sampler(iter_potential, item_potential)

Create a SetTokenSampler that uses the EagerSetSampler to sample a set of tokens.

See EagerSetSampler for more details.

Parameters:

- `iter_potential` (`Potential`, required): A potential function defined over a vocabulary of iterables.
- `item_potential` (`Potential`, required): A potential function defined over a vocabulary of items which are elements of the iterables.

Returns:

- `SetTokenSampler`: A sampler that wraps an `EagerSetSampler`.

Note

This is the fastest sampler in most cases.
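
For example (both names are illustrative placeholders), pairing a potential defined over iterables with an item-level potential and handing the result to `SMC`:

token_sampler = eager_token_sampler(llm, fsa)
smc = SMC(token_sampler)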

Source code in genlm/control/sampler/__init__.py
def eager_token_sampler(iter_potential, item_potential):
    """Create a `SetTokenSampler` that uses the `EagerSetSampler` to sample a set of tokens.

    See `EagerSetSampler` for more details.

    Args:
        iter_potential (Potential): A potential function defined over a vocabulary of iterables.
        item_potential (Potential): A potential function defined over a vocabulary of items which are elements of the iterables.

    Returns:
        (SetTokenSampler): A sampler that wraps an `EagerSetSampler`.

    Note:
        This is the fastest sampler in most cases.
    """
    return SetTokenSampler(EagerSetSampler(iter_potential, item_potential))

topk_token_sampler(iter_potential, item_potential, K)

Create a SetTokenSampler that uses the TopKSetSampler to sample a set of tokens.

See TopKSetSampler for more details.

Parameters:

- `iter_potential` (`Potential`, required): A potential function defined over a vocabulary of iterables.
- `item_potential` (`Potential`, required): A potential function defined over a vocabulary of items which are elements of the iterables.
- `K` (`int | None`, required): The `K` parameter for the `TopKSetSampler`.

Returns:

- `SetTokenSampler`: A sampler that wraps a `TopKSetSampler`.

Source code in genlm/control/sampler/__init__.py
def topk_token_sampler(iter_potential, item_potential, K):
    """Create a `SetTokenSampler` that uses the `TopKSetSampler` to sample a set of tokens.

    See `TopKSetSampler` for more details.

    Args:
        iter_potential (Potential): A potential function defined over a vocabulary of iterables.
        item_potential (Potential): A potential function defined over a vocabulary of items which are elements of the iterables.
        K (int|None): The `K` parameter for the `TopKSetSampler`.

    Returns:
        (SetTokenSampler): A sampler that wraps a `TopKSetSampler`.
    """
    return SetTokenSampler(TopKSetSampler(iter_potential, item_potential, K))