
potential

Potential

Bases: ABC, PotentialOps, PotentialTests

Abstract base class for potentials.

A Potential is a function that maps sequences of tokens in a vocabulary to non-negative real numbers (weights).

Potentials assign weights to sequences of tokens based on whether they are complete sequences or prefixes of complete sequences.

  • complete: Assess the log weight of a sequence of tokens in the vocabulary as a complete sequence.
  • prefix: Assess the log weight of a sequence of tokens in the vocabulary as a prefix.

Potentials additionally implement a logw_next method:

  • logw_next: Compute the next-token log weights of each token in the vocabulary and a special EOS (end-of-sequence) token given a context.

Subclasses must minimally implement complete and prefix. logw_next and batched versions of the above methods come with default implementations, but may be overridden by subclasses for improved performance.

All Potentials must satisfy a set of properties which can be tested using PotentialTests.
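Here is a minimal, self-contained sketch of the complete/prefix contract. It does not depend on genlm-control. ToyWordPotential is a hypothetical class, not part of the library. It only mirrors the async signatures described above: sequences that spell a word from a fixed set, or the start of one, get log weight 0.0 (weight 1). Everything else gets -inf (weight 0).

```python
import asyncio
import math


class ToyWordPotential:
    """Hypothetical stand-in for a Potential subclass (does not import genlm)."""

    def __init__(self, words):
        self.words = set(words)
        # Vocabulary: every character that appears in some word.
        self.vocab = sorted({ch for w in words for ch in w})

    async def complete(self, context):
        # Log weight 0 (weight 1) if the context spells a whole word, else -inf.
        return 0.0 if "".join(context) in self.words else -math.inf

    async def prefix(self, context):
        # Log weight 0 if the context can still be extended to some word.
        s = "".join(context)
        return 0.0 if any(w.startswith(s) for w in self.words) else -math.inf


async def main():
    p = ToyWordPotential(["cat", "car"])
    return (
        await p.prefix(["c", "a"]),
        await p.complete(["c", "a"]),
        await p.complete(["c", "a", "t"]),
    )


results = asyncio.run(main())
print(results)  # (0.0, -inf, 0.0)
```

A real subclass would inherit from Potential and pass its vocabulary to Potential.__init__. That wiring is omitted here so the sketch runs without the library installed.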

Attributes:

  • token_type (TokenType): The type of tokens in the vocabulary.
  • vocab (list): List of tokens making up the vocabulary.
  • eos (EndOfSequence): Special token to use as end-of-sequence.
  • vocab_eos (list): List of tokens in vocab and eos. eos is assumed to be the last token in vocab_eos.
  • lookup (dict): Mapping from tokens and eos to their indices in vocab_eos.
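The relationship between vocab, vocab_eos and lookup takes only a few lines to sketch. build_index and the EndOfSequenceSketch sentinel below are hypothetical stand-ins. They mirror the layout built in __init__: eos is appended last, so its index equals len(vocab).

```python
class EndOfSequenceSketch:
    """Hypothetical sentinel standing in for genlm's EOS object."""

    def __repr__(self):
        return "EOS"


EOS = EndOfSequenceSketch()


def build_index(vocab, eos=EOS):
    """Mirror how Potential.__init__ lays out vocab_eos and lookup."""
    vocab_eos = list(vocab) + [eos]  # eos always goes last
    lookup = {tok: i for i, tok in enumerate(vocab)}
    lookup[eos] = len(vocab)  # so the index of eos equals len(vocab)
    return vocab_eos, lookup


vocab_eos, lookup = build_index([b"a", b"b", b"c"])
print(vocab_eos)    # [b'a', b'b', b'c', EOS]
print(lookup[EOS])  # 3
```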

Source code in genlm/control/potential/base.py
class Potential(ABC, PotentialOps, PotentialTests):
    """Abstract base class for potentials.

    A Potential is a function that maps sequences of tokens in a vocabulary to non-negative real numbers (weights).

    Potentials assign weights to sequences of tokens based on whether they are complete sequences or prefixes of complete sequences.

    - `complete`: Assess the log weight of a sequence of tokens in the vocabulary as a complete sequence.
    - `prefix`: Assess the log weight of a sequence of tokens in the vocabulary as a prefix.

    Potentials additionally implement a `logw_next` method:

    - `logw_next`: Compute the next-token log weights of each token in the vocabulary and a special EOS (end-of-sequence) token given a context.

    Subclasses must minimally implement `complete` and `prefix`. `logw_next` and batched versions of the above methods
    come with default implementations, but may be overridden by subclasses for improved performance.

    All Potentials must satisfy a set of properties which can be tested using [PotentialTests][genlm.control.potential.testing.PotentialTests].

    Attributes:
        token_type (TokenType): The type of tokens in the vocabulary.
        vocab (list): List of tokens making up the vocabulary.
        eos (EndOfSequence): Special token to use as end-of-sequence.
        vocab_eos (list): List of tokens in `vocab` and `eos`. `eos` is assumed to be the last token in `vocab_eos`.
        lookup (dict): Mapping from tokens and `eos` to their indices in `vocab_eos`.
    """

    def __init__(self, vocabulary, token_type=None, eos=None):
        """
        Initialize the potential.

        Args:
            vocabulary (list): List of tokens that make up the vocabulary.
            token_type (TokenType, optional): Optional TokenType of all elements of the vocabulary.
                If None, will be inferred from vocabulary.
            eos (EndOfSequence, optional): Special token to use as end-of-sequence. Defaults to `EOS` sentinel.

        Raises:
            ValueError: If vocabulary is empty.
            TypeError: If vocabulary contains tokens which are not of `token_type`.
        """
        if not vocabulary:
            raise ValueError("vocabulary cannot be empty")

        if token_type is None:
            token_type = infer_vocabulary_type(vocabulary)
        elif not isinstance(token_type, TokenType):
            raise ValueError(f"token_type must be a TokenType, got {token_type!r}.")

        if not all(token_type.check(x) for x in vocabulary):
            raise TypeError(f"Tokens in vocabulary must be of type {token_type}.")

        if eos is not None and not isinstance(eos, EndOfSequence):
            raise ValueError("EOS must be an instance of EndOfSequence")

        self.eos = eos if eos is not None else EOS

        self.token_type = token_type
        self.vocab = vocabulary
        self.vocab_eos = self.vocab + [self.eos]
        self.lookup = {}
        for i, x in enumerate(vocabulary):
            if x in self.lookup:
                raise ValueError(f"Duplicate token {x!r} found in vocabulary")
            self.lookup[x] = i
        self.lookup[self.eos] = len(self.vocab)

    ####################
    # Instance methods #
    ####################

    @abstractmethod
    async def complete(self, context) -> float:
        """Assess the weight of `context` as a complete sequence.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (float): Log weight of the context under the language.
        """
        return 0.0  # pragma: no cover

    @abstractmethod
    async def prefix(self, context) -> float:
        """Assess the weight of `context` as a prefix.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (float): Log weight of the context as a prefix.
        """
        return 0.0  # pragma: no cover

    async def score(self, context):
        """Assess the weight of `context` based on EOS-termination.

        This is a convenience method which dispatches to `complete` if `context` ends with `self.eos`, otherwise to `prefix`.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (float): Log weight of the context, either as a prefix or complete sequence.
        """
        if context and context[-1] == self.eos:
            return await self.complete(context[:-1])
        else:
            return await self.prefix(context)

    async def logw_next(self, context):
        """Compute the next-token weights of each token in `self.vocab_eos` given `context`.

        Args:
            context (list): Sequence of tokens.

        Returns:
            (LazyWeights): Weights of each token in the vocabulary and EOS.
        """
        ctx_log_w = await self.prefix(context)

        if ctx_log_w == float("-inf"):
            raise ValueError(f"Context {context!r} has weight zero under `prefix`.")

        scores = await self.batch_score([[*context, x] for x in self.vocab_eos])
        logws = scores - ctx_log_w

        return self.make_lazy_weights(logws)

    ###################
    # Batched methods #
    ###################

    async def batch_complete(self, contexts):
        """Batched equivalent to `complete`.

        Assess the weight of each context as a complete sequence.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (np.array): Array of log weights for each context.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        return np.array(
            await asyncio.gather(*[self.complete(context) for context in contexts])
        )

    async def batch_prefix(self, contexts):
        """Batched equivalent to `prefix`.

        Assess the weight of each context as a prefix.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (np.array): Array of log weights for each context.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        return np.array(
            await asyncio.gather(*[self.prefix(context) for context in contexts])
        )

    async def batch_score(self, contexts):
        """Batched equivalent to `score`.

        Assess the weight of each context based on EOS-termination.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (np.array): Array of log weights for each context.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        complete, prefix = [], []
        complete_indices, prefix_indices = [], []

        for i, context in enumerate(contexts):
            # We want == here instead of `is`.
            if context and context[-1] == self.eos:
                complete.append(context[:-1])
                complete_indices.append(i)
            else:
                prefix.append(context)
                prefix_indices.append(i)

        complete_scores = (
            await self.batch_complete(complete) if complete else np.array([])
        )
        prefix_scores = await self.batch_prefix(prefix) if prefix else np.array([])

        results = np.empty(len(contexts))
        if len(complete_scores) > 0:
            results[complete_indices] = complete_scores
        if len(prefix_scores) > 0:
            results[prefix_indices] = prefix_scores

        return results

    async def batch_logw_next(self, contexts):
        """Batched equivalent to `logw_next`.

        Computes the next-token weights of each token in `self.vocab_eos` given each context in the batch.

        Args:
            contexts (list): List of sequences of tokens.

        Returns:
            (list): List of LazyWeights objects, one for each context.

        Raises:
            ValueError: If any context has zero weight (log weight of -inf) under `prefix`.
        """
        if not contexts:
            raise ValueError("Contexts must be non-empty.")

        return await asyncio.gather(*[self.logw_next(context) for context in contexts])

    #############
    # Utilities #
    #############

    def make_lazy_weights(self, weights, log=True):
        """Helper method to create a LazyWeights object over the potential's vocabulary and EOS.

        Args:
            weights (np.array): Array of weights.
            log (bool, optional): Whether the weights are in log space. Defaults to True.

        Returns:
            (LazyWeights): LazyWeights object defined over `self.vocab_eos`.
        """
        return LazyWeights(
            weights=weights, encode=self.lookup, decode=self.vocab_eos, log=log
        )

    def alloc_logws(self, default=float("-inf")):
        """Allocate a new array of log weights for the potential's vocabulary and EOS.

        Args:
            default (float, optional): Default log weight. Defaults to -inf.

        Returns:
            (np.array): Array of length `len(self.vocab_eos)` filled with `default`.
        """
        return np.full((len(self.vocab_eos),), default)

    def spawn(self):
        """
        Spawn a fresh instance of the potential.

        This method is not required by default, but may be implemented by subclasses
        to support CPU-parallelism using [`MultiProcPotential`][genlm.control.potential.multi_proc.MultiProcPotential].
        """
        raise NotImplementedError(
            "Potential.spawn() must be implemented by subclasses."
        )

    async def cleanup(self):
        """
        Cleanup the potential.

        This method may be implemented by subclasses to release resources.
        """
        pass

__init__(vocabulary, token_type=None, eos=None)

Initialize the potential.

Parameters:

  • vocabulary (list, required): List of tokens that make up the vocabulary.
  • token_type (TokenType, default None): Optional TokenType of all elements of the vocabulary. If None, will be inferred from vocabulary.
  • eos (EndOfSequence, default None): Special token to use as end-of-sequence. Defaults to EOS sentinel.

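Raises:

  • ValueError: If vocabulary is empty or contains duplicate tokens, if token_type is given but is not a TokenType, or if eos is given but is not an EndOfSequence.
  • TypeError: If vocabulary contains tokens which are not of token_type.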
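The validation order in __init__ can be exercised without the library. validate_vocabulary is a hypothetical helper that reproduces only the checks on the vocabulary itself. A check callable stands in for token_type.check. The token_type and eos instance checks are omitted.

```python
def validate_vocabulary(vocabulary, check=lambda tok: True):
    """Hypothetical helper mirroring the order of checks in Potential.__init__."""
    if not vocabulary:
        raise ValueError("vocabulary cannot be empty")
    if not all(check(x) for x in vocabulary):
        raise TypeError("Tokens in vocabulary must all have the expected token type.")
    seen = set()
    for x in vocabulary:
        if x in seen:
            raise ValueError(f"Duplicate token {x!r} found in vocabulary")
        seen.add(x)


def error_kind(vocabulary, check):
    """Return the name of the exception raised, or None if validation passes."""
    try:
        validate_vocabulary(vocabulary, check)
    except (ValueError, TypeError) as err:
        return type(err).__name__
    return None


def is_bytes(tok):
    return isinstance(tok, bytes)


for vocab in ([], [b"a", b"b"], [b"a", b"a"], [b"a", "b"]):
    print(vocab, error_kind(vocab, is_bytes))
```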


complete(context) abstractmethod async

Assess the weight of context as a complete sequence.

Parameters:

  • context (list, required): Sequence of tokens.

Returns:

  • float: Log weight of the context under the language.


prefix(context) abstractmethod async

Assess the weight of context as a prefix.

Parameters:

  • context (list, required): Sequence of tokens.

Returns:

  • float: Log weight of the context as a prefix.


score(context) async

Assess the weight of context based on EOS-termination.

This is a convenience method which dispatches to complete if context ends with self.eos, otherwise to prefix.

Parameters:

  • context (list, required): Sequence of tokens.

Returns:

  • float: Log weight of the context, either as a prefix or complete sequence.
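The dispatch rule in score can be checked in isolation. In this sketch, EOS is a hypothetical sentinel object and prefix/complete are toy functions. Only the if-branch mirrors the documented behaviour. An empty context is always treated as a prefix.

```python
import asyncio

EOS = object()  # hypothetical stand-in for the EOS sentinel


async def prefix(context):
    return float(-len(context))  # toy prefix log weight: -1 per token


async def complete(context):
    return -10.0  # toy complete log weight


async def score(context):
    # Same dispatch rule as Potential.score: a trailing EOS means "complete",
    # and the EOS itself is stripped before calling complete.
    if context and context[-1] == EOS:
        return await complete(context[:-1])
    return await prefix(context)


print(asyncio.run(score(["a", "b"])))       # -2.0
print(asyncio.run(score(["a", "b", EOS])))  # -10.0
print(asyncio.run(score([])))               # 0.0
```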


logw_next(context) async

Compute the next-token weights of each token in self.vocab_eos given context.

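Parameters:

  • context (list, required): Sequence of tokens.

Returns:

  • LazyWeights: Log weights of each token in the vocabulary and EOS. In the default implementation, the log weight of token x is score(context + [x]) - prefix(context).

Raises:

  • ValueError: In the default implementation, if context has zero weight (log weight of -inf) under prefix.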
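The default logw_next is a ratio in log space: for each x in vocab_eos it returns score(context + [x]) - prefix(context). The sketch below re-implements that computation as free functions over a hypothetical character-level word set. It returns a plain dict instead of a LazyWeights object and uses a string in place of the real EOS token.

```python
import asyncio
import math

import numpy as np

EOS = "<eos>"  # hypothetical EOS stand-in (a plain string for readability)
WORDS = {"cat", "car", "ca"}
VOCAB = ["a", "c", "r", "t"]
VOCAB_EOS = VOCAB + [EOS]


async def prefix(context):
    s = "".join(context)
    return 0.0 if any(w.startswith(s) for w in WORDS) else -math.inf


async def complete(context):
    return 0.0 if "".join(context) in WORDS else -math.inf


async def score(context):
    if context and context[-1] == EOS:
        return await complete(context[:-1])
    return await prefix(context)


async def logw_next(context):
    """Mirror of the default Potential.logw_next, returning a plain dict."""
    ctx_log_w = await prefix(context)
    if ctx_log_w == -math.inf:
        raise ValueError(f"Context {context!r} has weight zero under `prefix`.")
    # One score call per candidate next token (including EOS).
    scores = np.array(await asyncio.gather(*[score([*context, x]) for x in VOCAB_EOS]))
    logws = scores - ctx_log_w
    return dict(zip(VOCAB_EOS, logws.tolist()))


print(asyncio.run(logw_next(["c", "a"])))
# {'a': -inf, 'c': -inf, 'r': 0.0, 't': 0.0, '<eos>': 0.0}
```

The default issues one score call per entry of vocab_eos at every step. Subclasses with large vocabularies therefore usually benefit from overriding logw_next.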


batch_complete(contexts) async

Batched equivalent to complete.

Assess the weight of each context as a complete sequence.

Parameters:

  • contexts (list, required): List of sequences of tokens.

Returns:

  • np.array: Array of log weights for each context.

Raises:

  • ValueError: If contexts is empty.
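The default batched methods are thin wrappers around asyncio.gather. The sketch below uses a toy complete whose artificial delays make later contexts finish first. gather still returns results in input order, so each index lines up with its context.

```python
import asyncio

import numpy as np


async def complete(context):
    # Toy per-context log weight; shorter contexts sleep longer, so the
    # coroutines finish in a different order than they were started.
    await asyncio.sleep(0.01 * (3 - len(context)))
    return float(-len(context))


async def batch_complete(contexts):
    """Same pattern as the default Potential.batch_complete."""
    if not contexts:
        raise ValueError("Contexts must be non-empty.")
    # asyncio.gather runs the coroutines concurrently but returns results
    # in the order the awaitables were passed in.
    return np.array(await asyncio.gather(*[complete(c) for c in contexts]))


res = asyncio.run(batch_complete([["a"], ["a", "b"], []]))
print(res.tolist())  # [-1.0, -2.0, 0.0]
```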


batch_prefix(contexts) async

Batched equivalent to prefix.

Assess the weight of each context as a prefix.

Parameters:

  • contexts (list, required): List of sequences of tokens.

Returns:

  • np.array: Array of log weights for each context.

Raises:

  • ValueError: If contexts is empty.


batch_score(contexts) async

Batched equivalent to score.

Assess the weight of each context based on EOS-termination.

Parameters:

  • contexts (list, required): List of sequences of tokens.

Returns:

  • np.array: Array of log weights for each context.

Raises:

  • ValueError: If contexts is empty.
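batch_score works in two steps. It first splits the batch into EOS-terminated contexts, which are scored with batch_complete after dropping the EOS, and the rest, which are scored with batch_prefix. It then scatters both result arrays back into the original order. The helpers below are hypothetical names for those two steps. Made-up scores stand in for the two batched calls, and a string stands in for EOS.

```python
import numpy as np

EOS = "<eos>"  # hypothetical EOS stand-in


def split_by_eos(contexts, eos=EOS):
    """Partition contexts the way Potential.batch_score does."""
    complete, complete_idx, prefix, prefix_idx = [], [], [], []
    for i, ctx in enumerate(contexts):
        if ctx and ctx[-1] == eos:
            complete.append(ctx[:-1])  # strip EOS before scoring as complete
            complete_idx.append(i)
        else:
            prefix.append(ctx)
            prefix_idx.append(i)
    return complete, complete_idx, prefix, prefix_idx


def scatter(n, complete_idx, complete_scores, prefix_idx, prefix_scores):
    """Write each group's scores back into their original positions."""
    results = np.empty(n)
    if len(complete_scores) > 0:
        results[complete_idx] = complete_scores
    if len(prefix_scores) > 0:
        results[prefix_idx] = prefix_scores
    return results


contexts = [["a"], ["a", EOS], [], ["b", "c", EOS]]
c, ci, p, pi = split_by_eos(contexts)
print(c, ci)  # [['a'], ['b', 'c']] [1, 3]
print(p, pi)  # [['a'], []] [0, 2]

# Pretend complete() scores -10 per token and prefix() scores -1 per token.
res = scatter(
    len(contexts),
    ci, [float(-10 * len(x)) for x in c],
    pi, [float(-len(x)) for x in p],
)
print(res.tolist())  # [-1.0, -10.0, 0.0, -20.0]
```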


batch_logw_next(contexts) async

Batched equivalent to logw_next.

Computes the next-token weights of each token in self.vocab_eos given each context in the batch.

Parameters:

  • contexts (list, required): List of sequences of tokens.

Returns:

  • list: List of LazyWeights objects, one for each context.

Raises:

  • ValueError: If contexts is empty, or if any context has zero weight (log weight of -inf) under prefix.
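The default batch_logw_next uses asyncio.gather without return_exceptions. As a result, a single zero-weight context makes the whole call raise. The logw_next in this sketch is a hypothetical toy that returns a dict instead of LazyWeights.

```python
import asyncio


async def logw_next(context):
    # Hypothetical rule: contexts starting with "x" have zero prefix weight.
    if context and context[0] == "x":
        raise ValueError(f"Context {context!r} has weight zero under `prefix`.")
    return {"a": -0.5, "b": -1.5}  # toy stand-in for a LazyWeights object


async def batch_logw_next(contexts):
    if not contexts:
        raise ValueError("Contexts must be non-empty.")
    # With the default return_exceptions=False, gather re-raises the first
    # exception raised by any of the awaitables.
    return await asyncio.gather(*[logw_next(c) for c in contexts])


async def main():
    ok = await batch_logw_next([["a"], ["b"]])
    try:
        await batch_logw_next([["a"], ["x"]])
    except ValueError as err:
        return ok, str(err)


ok, msg = asyncio.run(main())
print(msg)  # Context ['x'] has weight zero under `prefix`.
```

gather does not cancel the remaining awaitables when one fails. It only propagates the first exception. Callers that want per-context errors could issue the calls themselves with return_exceptions=True.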


make_lazy_weights(weights, log=True)

Helper method to create a LazyWeights object over the potential's vocabulary and EOS.

Parameters:

  • weights (np.array, required): Array of weights.
  • log (bool, default True): Whether the weights are in log space. Defaults to True.

Returns:

  • LazyWeights: LazyWeights object defined over self.vocab_eos.
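make_lazy_weights pairs a flat weight array with self.lookup (token to index) and self.vocab_eos (index to token). ToyLazyWeights below is a hypothetical stand-in that only illustrates this encode/decode pairing. The real LazyWeights class is not documented on this page, and its methods may differ.

```python
import math

import numpy as np


class ToyLazyWeights:
    """Hypothetical stand-in; genlm's LazyWeights may expose a different API."""

    def __init__(self, weights, encode, decode, log=True):
        self.weights = np.asarray(weights, dtype=float)
        self.encode = encode  # token -> index (like Potential.lookup)
        self.decode = decode  # index -> token (like Potential.vocab_eos)
        self.is_log = log     # whether weights are log-space (stored, unused here)

    def __getitem__(self, token):
        # Look a token up by value instead of by position.
        return self.weights[self.encode[token]]

    def to_dict(self):
        # Walk positions in order and label each weight with its token.
        return {self.decode[i]: w for i, w in enumerate(self.weights.tolist())}


vocab_eos = ["a", "b", "<eos>"]  # "<eos>" stands in for the real EOS token
lookup = {tok: i for i, tok in enumerate(vocab_eos)}
lw = ToyLazyWeights([math.log(0.5), math.log(0.25), math.log(0.25)], lookup, vocab_eos)
print(round(float(lw["b"]), 4))  # -1.3863
```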


alloc_logws(default=float('-inf'))

Allocate a new array of log weights for the potential's vocabulary and EOS.

Parameters:

  • default (float, default float('-inf')): Default log weight. Defaults to -inf.

Returns:

  • np.array: Array of length len(self.vocab_eos) filled with default.
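One plausible use of alloc_logws is inside an overridden logw_next. The override starts from an array of -inf (weight zero everywhere) and sets only the permitted tokens. allowed_next_logws is a hypothetical function, and a string stands in for the real EOS token.

```python
import numpy as np

vocab = ["a", "b", "c"]
vocab_eos = vocab + ["<eos>"]  # "<eos>" is a hypothetical EOS stand-in
lookup = {tok: i for i, tok in enumerate(vocab_eos)}


def alloc_logws(default=float("-inf")):
    # Same shape as Potential.alloc_logws: one slot per token plus EOS.
    return np.full((len(vocab_eos),), default)


def allowed_next_logws(allowed):
    """Hypothetical custom logw_next body: 0.0 for allowed tokens, -inf elsewhere."""
    logws = alloc_logws()
    for tok in allowed:
        logws[lookup[tok]] = 0.0
    return logws


print(allowed_next_logws({"b", "<eos>"}).tolist())  # [-inf, 0.0, -inf, 0.0]
```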

Source code in genlm/control/potential/base.py
def alloc_logws(self, default=float("-inf")):
    """Allocate a new array of log weights for the potential's vocabulary and EOS.

    Args:
        default (float, optional): Default log weight. Defaults to -inf.

    Returns:
        (np.array): Array of length `len(self.vocab_eos)` filled with `default`.
    """
    return np.full((len(self.vocab_eos),), default)

spawn()

Spawn a fresh instance of the potential.

Subclasses are not required to implement this method, but may do so to support CPU parallelism via MultiProcPotential. The base implementation raises NotImplementedError.
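The sketch below shows one way a subclass might implement spawn. It assumes the potential can be rebuilt from its constructor arguments. CountingPotential is hypothetical. The point is that the spawned copy starts with fresh state instead of sharing mutable state with the original.

```python
class CountingPotential:
    """Hypothetical potential with per-instance mutable state (a call counter)."""

    def __init__(self, vocab):
        self.vocab = list(vocab)
        self.calls = 0

    def spawn(self):
        # Build a fresh instance from the same constructor arguments rather
        # than sharing mutable state with the original.
        return CountingPotential(self.vocab)


p = CountingPotential(["a", "b"])
p.calls = 5
q = p.spawn()
print(q.calls, q.vocab == p.vocab, q is p)  # 0 True False
```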


cleanup() async

Cleanup the potential.

This method may be implemented by subclasses to release resources.


AutoBatchedPotential

Bases: Potential

AutoBatchedPotential is a wrapper around a Potential that enables automatic batching of concurrent requests.

This class manages a background loop that collects concurrent requests to instance methods (complete, prefix, score, logw_next) and batches them together before delegating to the corresponding batch methods of the underlying potential (batch_complete, batch_prefix, batch_score, batch_logw_next).

This class inherits all methods from Potential.

Attributes:

  • potential (Potential): The underlying potential instance that is being wrapped.
  • background_loop (AsyncBatchLoop): An asynchronous loop that manages batch requests.
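The background loop itself (AsyncBatchLoop) is not shown on this page. The general idea can be illustrated with a hypothetical MiniAutoBatcher built on asyncio.Queue. It waits for one request and drains any others already queued. It then issues a single batched call and resolves each caller's future with its own result. Its names and structure are assumptions, not genlm's implementation.

```python
import asyncio


class MiniAutoBatcher:
    """Hypothetical sketch of auto-batching; not genlm's AsyncBatchLoop."""

    def __init__(self, batch_fn):
        self.batch_fn = batch_fn  # async: list of inputs -> list of outputs
        self.queue = asyncio.Queue()
        self.batch_sizes = []     # recorded only for illustration
        self.task = None

    def start(self):
        self.task = asyncio.create_task(self._loop())

    async def _loop(self):
        while True:
            # Wait for one request, then drain any others already queued.
            items = [await self.queue.get()]
            while not self.queue.empty():
                items.append(self.queue.get_nowait())
            self.batch_sizes.append(len(items))
            try:
                outputs = await self.batch_fn([x for x, _ in items])
            except Exception as err:
                for _, fut in items:
                    fut.set_exception(err)  # every caller in the batch sees the error
                continue
            for (_, fut), out in zip(items, outputs):
                fut.set_result(out)  # resolve each caller with its own result

    async def request(self, x):
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((x, fut))
        return await fut

    async def close(self):
        self.task.cancel()
        try:
            await self.task
        except asyncio.CancelledError:
            pass


async def double_all(xs):
    await asyncio.sleep(0)  # stand-in for one batched model call
    return [2 * x for x in xs]


async def main():
    batcher = MiniAutoBatcher(double_all)
    batcher.start()
    results = await asyncio.gather(*[batcher.request(i) for i in range(5)])
    await batcher.close()
    return results, batcher.batch_sizes


results, sizes = asyncio.run(main())
print(results)  # [0, 2, 4, 6, 8]
```

With five concurrent requests as above, the drain step typically collects them into one batch. Awaiting the requests one after another would instead produce batches of size one. This is why auto-batching only pays off when callers run concurrently.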

Source code in genlm/control/potential/autobatch.py
class AutoBatchedPotential(Potential):
    """
    AutoBatchedPotential is a wrapper around a Potential that enables automatic batching of concurrent requests.

    This class manages a background loop that collects concurrent requests to instance methods
    (`complete`, `prefix`, `score`, `logw_next`) and batches them together before
    delegating to the corresponding batch methods of the underlying potential
    (`batch_complete`, `batch_prefix`, `batch_score`, `batch_logw_next`).

    This class inherits all methods from [`Potential`][genlm.control.potential.base.Potential].

    Attributes:
        potential (Potential): The underlying potential instance that is being wrapped.
        background_loop (AsyncBatchLoop): An asynchronous loop that manages batch requests.
    """

    def __init__(self, potential):
        self.potential = potential
        self.background_loop = AsyncBatchLoop(potential)
        self.background_loop.start()
        super().__init__(
            potential.vocab, token_type=potential.token_type, eos=potential.eos
        )

    async def complete(self, context):
        return await self.background_loop.queue_request(
            "batch_complete", lambda args: ([*args[0], context],)
        )

    async def prefix(self, context):
        return await self.background_loop.queue_request(
            "batch_prefix", lambda args: ([*args[0], context],)
        )

    async def score(self, context):
        return await self.background_loop.queue_request(
            "batch_score", lambda args: ([*args[0], context],)
        )

    async def logw_next(self, context):
        return await self.background_loop.queue_request(
            "batch_logw_next", lambda args: ([*args[0], context],)
        )

    async def batch_complete(self, contexts):
        return await self.potential.batch_complete(contexts)

    async def batch_prefix(self, contexts):
        return await self.potential.batch_prefix(contexts)

    async def batch_score(self, contexts):
        return await self.potential.batch_score(contexts)

    async def batch_logw_next(self, contexts):
        return await self.potential.batch_logw_next(contexts)

    def spawn(self, *args, **kwargs):
        # creates a new background loop.
        return AutoBatchedPotential(self.potential.spawn(*args, **kwargs))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.potential!r})"

    async def cleanup(self):
        """Async cleanup - preferred method"""
        await self.background_loop.cleanup()

    def __del__(self):
        if loop := getattr(self, "background_loop", None):
            loop.close()

cleanup() async

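Asynchronously shut down the background batching loop. This is the preferred way to release its resources; __del__ only closes the loop as a fallback.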


MultiProcPotential

Bases: Potential

A Potential that adds parallel processing capabilities to any base Potential implementation.

Creates a process pool of worker processes, each containing an instance of the potential.

This class inherits all methods from Potential. Each method delegates to a corresponding method of the potential instances running in the worker processes, distributing work across multiple processes for improved performance.
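The constructor below follows a common concurrent.futures pattern. An initializer builds one instance per worker process and stores it in a module-level global, and each call is dispatched with loop.run_in_executor. The sketch reproduces that pattern with a hypothetical make_scaler factory in place of a potential factory. It assumes a POSIX system, because it requests the fork start method explicitly so that the functions defined in the script are available in the workers.

```python
import asyncio
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

_worker_fn = None  # per-process global, filled in once by the initializer


def _init_worker(factory, args):
    # Runs once inside each worker process: build that worker's own instance.
    global _worker_fn
    _worker_fn = factory(*args)


def _worker_call(x):
    # Runs inside a worker and uses the instance created by _init_worker.
    return _worker_fn(x)


def make_scaler(k):
    # Hypothetical factory; a real one would construct a potential.
    return lambda x: k * x


async def run_all(executor, xs):
    loop = asyncio.get_running_loop()
    # The function is sent by reference, each argument is pickled,
    # and a free worker process runs the call.
    return await asyncio.gather(
        *(loop.run_in_executor(executor, _worker_call, x) for x in xs)
    )


ctx = multiprocessing.get_context("fork")  # assumption: POSIX platform
with ProcessPoolExecutor(
    max_workers=2,
    mp_context=ctx,
    initializer=_init_worker,
    initargs=(make_scaler, (3,)),
) as executor:
    results = asyncio.run(run_all(executor, [1, 2, 3]))

print(results)  # [3, 6, 9]
```

Results travel back to the parent by pickling. This is presumably why _worker_logw_next returns the raw .weights array, with the parent rebuilding a LazyWeights object via make_lazy_weights.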

Source code in genlm/control/potential/multi_proc.py
class MultiProcPotential(Potential):
    """A Potential that adds parallel processing capabilities to any base Potential implementation.

    Creates a process pool of worker processes, each containing an instance of the potential.

    This class inherits all methods from [`Potential`][genlm.control.potential.base.Potential].
    Each method delegates to a corresponding method of the potential instances running in the
    worker processes, distributing work across multiple processes for improved performance.
    """

    def __init__(self, potential_factory, factory_args, num_workers=2):
        """
        Initialize the MultiProcPotential.

        Args:
            potential_factory (callable): A factory function that creates a potential instance.
            factory_args (tuple): Arguments to pass to the potential factory.
            num_workers (int): The number of worker processes to spawn. Each will contain an instance of the potential.
        """
        self.num_workers = num_workers
        self.executor = ProcessPoolExecutor(
            max_workers=num_workers,
            initializer=self._init_worker,
            initargs=(potential_factory, factory_args),
        )
        # Get vocab and eos from one of the workers
        vocab, eos = self.executor.submit(self._get_vocab_and_eos).result()
        super().__init__(vocab, eos=eos)

    @staticmethod
    def _init_worker(factory, args):
        global _worker_potential, _worker_event_loop
        _worker_potential = factory(*args)
        _worker_event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(_worker_event_loop)

    @staticmethod
    def _get_vocab_and_eos():
        return _worker_potential.vocab, _worker_potential.eos

    @staticmethod
    def _run_coroutine(coroutine):
        global _worker_event_loop
        return _worker_event_loop.run_until_complete(coroutine)

    @staticmethod
    def _worker_logw_next(context):
        return MultiProcPotential._run_coroutine(
            _worker_potential.logw_next(context)
        ).weights

    @staticmethod
    def _worker_prefix(context):
        return MultiProcPotential._run_coroutine(_worker_potential.prefix(context))

    @staticmethod
    def _worker_complete(context):
        return MultiProcPotential._run_coroutine(_worker_potential.complete(context))

    # @staticmethod
    # def _worker_score(context):
    #    return MultiProcPotential._run_coroutine(_worker_potential.score(context))

    async def _run_in_executor(self, func, *args):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    async def logw_next(self, context):
        result = await self._run_in_executor(self._worker_logw_next, context)
        return self.make_lazy_weights(result)

    async def prefix(self, context):
        return await self._run_in_executor(self._worker_prefix, context)

    async def complete(self, context):
        return await self._run_in_executor(self._worker_complete, context)

    async def batch_logw_next(self, contexts):
        results = await asyncio.gather(
            *(
                self._run_in_executor(self._worker_logw_next, context)
                for context in contexts
            )
        )
        return [self.make_lazy_weights(result) for result in results]

    async def batch_complete(self, contexts):
        results = await asyncio.gather(
            *(
                self._run_in_executor(self._worker_complete, context)
                for context in contexts
            )
        )
        return np.array(results)

    async def batch_prefix(self, contexts):
        results = await asyncio.gather(
            *(
                self._run_in_executor(self._worker_prefix, context)
                for context in contexts
            )
        )
        return np.array(results)

    def __del__(self):
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_workers=})"

    def spawn(self):
        raise ValueError("MultiProcPotentials are not spawnable.")

__init__(potential_factory, factory_args, num_workers=2)

Initialize the MultiProcPotential.

Parameters:

  • potential_factory (callable, required): A factory function that creates a potential instance.
  • factory_args (tuple, required): Arguments to pass to the potential factory.
  • num_workers (int, default 2): The number of worker processes to spawn. Each will contain an instance of the potential.


PotentialOps

Mixin providing operations for potential functions:

  1. Product (*): Take the product of two potentials.

  2. Coercion (coerce): Coerce the potential to operate on another potential's vocabulary.

  3. Auto-batching (to_autobatched): Create a version that automatically batches concurrent requests to the instance methods.

  4. Parallelization (to_multiprocess): Create a version that parallelizes operations over multiple processes.

Source code in genlm/control/potential/operators.py
class PotentialOps:
    """Mixin providing operations for potential functions:

    1. Product (`*`): Take the product of two potentials.\n
    2. Coercion (`coerce`): Coerce the potential to operate on another potential's vocabulary.\n
    3. Auto-batching (`to_autobatched`): Create a version that automatically batches concurrent requests to the instance methods.\n
    4. Parallelization (`to_multiprocess`): Create a version that parallelizes operations over multiple processes.\n
    """

    def __mul__(self, other):
        """Take the product of two potentials.

        See [`Product`][genlm.control.potential.product.Product] for more details.

        Args:
            other (Potential): Another potential instance to take the product with.

        Returns:
            (Product): A Product instance representing the unnormalized product of the two potentials.

        Note:
            Potentials must operate on the same token type and the intersection of their vocabularies must be non-empty.
        """
        from genlm.control.potential.product import Product

        return Product(self, other)

    def coerce(self, other, f, prune=True):
        """Coerce the current potential to operate on the vocabulary of another potential.

        See [`Coerced`][genlm.control.potential.coerce.Coerced] for more details.

        Args:
            other (Potential): The potential instance whose vocabulary will be used.
            f (callable): A function mapping sequences of tokens from other's vocab to sequences of tokens from self's vocab.
            prune (bool): Whether to prune the coerced potential's vocabulary to only include tokens that can be mapped to the original potential's vocabulary.
                If `False`, the coerced potential's vocabulary will include all tokens from the target vocabulary.

        Returns:
            (Coerced): A Potential that operates on the vocabulary of `other`.
        """
        from genlm.control.potential.coerce import Coerced

        return Coerced(self, other.vocab, f=f, prune=prune)

    def to_autobatched(self):
        """Create a new potential instance that automatically batches concurrent requests to the instance methods.

        See [`AutoBatchedPotential`][genlm.control.potential.autobatch.AutoBatchedPotential] for more details.

        Returns:
            (AutoBatchedPotential): A new potential instance that wraps the current potential and automatically batches concurrent requests to the instance methods.
        """
        from genlm.control.potential.autobatch import AutoBatchedPotential

        return AutoBatchedPotential(self)

    def to_multiprocess(self, num_workers=2, spawn_args=None):
        """Create a new potential instance that parallelizes operations using multiprocessing.

        See [`MultiProcPotential`][genlm.control.potential.multi_proc.MultiProcPotential] for more details.

        Args:
            num_workers (int): The number of workers to use in the multiprocessing pool.
            spawn_args (tuple): The positional arguments to pass to the potential's `spawn` method.

        Returns:
            (MultiProcPotential): A new potential instance that wraps the current potential and uses multiprocessing to parallelize operations.

        Note:
            For this method to be used, the potential must implement a picklable `spawn` method.
        """
        from genlm.control.potential.multi_proc import MultiProcPotential

        factory_args = spawn_args or ()
        return MultiProcPotential(
            potential_factory=self.spawn,
            factory_args=factory_args,
            num_workers=num_workers,
        )

__mul__(other)

Take the product of two potentials.

See Product for more details.

Parameters:

  • other (Potential, required): Another potential instance to take the product with.

Returns:

  • Product: A Product instance representing the unnormalized product of the two potentials.

Note

Potentials must operate on the same token type and the intersection of their vocabularies must be non-empty.
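
A minimal sketch of how `*` can build a product whose `complete` adds log weights and, like `Product.complete`, skips the second potential once the first returns `-inf`. `ToyPotential` and `ToyProduct` are hypothetical stand-ins, not genlm classes.

```python
import asyncio

NEG_INF = float("-inf")


class ToyPotential:
    """Hypothetical stand-in for a potential: wraps a sync log-weight function."""

    def __init__(self, logw):
        self.logw = logw
        self.calls = 0  # counts complete() calls, to show the short-circuit

    async def complete(self, context):
        self.calls += 1
        return self.logw(context)

    def __mul__(self, other):
        # Mirrors PotentialOps.__mul__: `p1 * p2` builds a product object.
        return ToyProduct(self, other)


class ToyProduct:
    """Product of two toy potentials: log weights add (weights multiply)."""

    def __init__(self, p1, p2):
        self.p1, self.p2 = p1, p2

    async def complete(self, context):
        w1 = await self.p1.complete(context)
        if w1 == NEG_INF:
            return NEG_INF  # zero weight times any non-negative weight is zero, so p2 is skipped
        return w1 + await self.p2.complete(context)
```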

coerce(other, f, prune=True)

Coerce the current potential to operate on the vocabulary of another potential.

See Coerced for more details.

Parameters:

  • other (Potential, required): The potential instance whose vocabulary will be used.
  • f (callable, required): A function mapping sequences of tokens from other's vocab to sequences of tokens from self's vocab.
  • prune (bool, default True): Whether to prune the coerced potential's vocabulary to only include tokens that can be mapped to the original potential's vocabulary. If False, the coerced potential's vocabulary will include all tokens from the target vocabulary.

Returns:

  • Coerced: A Potential that operates on the vocabulary of other.

to_autobatched()

Create a new potential instance that automatically batches concurrent requests to the instance methods.

See AutoBatchedPotential for more details.

Returns:

  • AutoBatchedPotential: A new potential instance that wraps the current potential and automatically batches concurrent requests to the instance methods.
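
The sketch below shows one way automatic batching can work with asyncio: concurrent single-context calls are queued, and a task scheduled on the event loop flushes them through one batch call. It is an illustration under stated assumptions, not `AutoBatchedPotential` itself; `AutoBatcher` and `batch_fn` are hypothetical names.

```python
import asyncio


class AutoBatcher:
    """Hypothetical sketch: coalesce concurrent single-context calls into one batch call."""

    def __init__(self, batch_fn):
        self.batch_fn = batch_fn  # async callable: list of contexts -> list of results
        self._pending = []        # (context, future) pairs waiting for the next flush
        self._scheduled = False   # True while a flush task is queued but has not run yet
        self._tasks = set()       # strong references so flush tasks are not garbage-collected

    async def __call__(self, context):
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self._pending.append((context, fut))
        if not self._scheduled:
            # The flush task runs only after callers already queued on the loop
            # have had a chance to enqueue their own contexts.
            self._scheduled = True
            task = loop.create_task(self._flush())
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
        return await fut

    async def _flush(self):
        batch, self._pending = self._pending, []
        self._scheduled = False  # later arrivals start a new batch
        try:
            results = await self.batch_fn([ctx for ctx, _ in batch])
        except Exception as exc:
            for _, fut in batch:
                if not fut.done():
                    fut.set_exception(exc)
            return
        for (_, fut), result in zip(batch, results):
            if not fut.done():  # skip callers that were cancelled meanwhile
                fut.set_result(result)
```

Flushing on the next loop iteration, rather than after a timeout, keeps latency low but only groups calls that are runnable at the same moment.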

to_multiprocess(num_workers=2, spawn_args=None)

Create a new potential instance that parallelizes operations using multiprocessing.

See MultiProcPotential for more details.

Parameters:

  • num_workers (int, default 2): The number of workers to use in the multiprocessing pool.
  • spawn_args (tuple, default None): The positional arguments to pass to the potential's spawn method.

Returns:

  • MultiProcPotential: A new potential instance that wraps the current potential and uses multiprocessing to parallelize operations.

Note

For this method to be used, the potential must implement a picklable spawn method.
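
Because `to_multiprocess` passes the bound method `self.spawn` as the worker factory, and depending on the multiprocessing start method that factory is pickled when it is sent to workers, the potential instance itself must survive pickling. A quick pre-flight check might look like the following. `spawn_is_picklable`, `ToyPotential`, and `LambdaPotential` are hypothetical, and the check only covers pickling the factory, not whether `spawn` succeeds inside a worker.

```python
import pickle


class ToyPotential:
    """Hypothetical potential whose spawn() builds a fresh, equivalent instance."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def spawn(self):
        return ToyPotential(self.scale)


class LambdaPotential:
    """Hypothetical potential holding a lambda, which the stdlib pickle cannot serialize."""

    def __init__(self):
        self.fn = lambda x: x

    def spawn(self):
        return LambdaPotential()


def spawn_is_picklable(potential):
    """Return True if the bound `spawn` method (and hence the instance) pickles."""
    try:
        pickle.dumps(potential.spawn)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True
```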

Product

Bases: Potential

Combine two potential instances via element-wise multiplication (sum in log space).

This class creates a new potential that is the element-wise product of two potentials:

prefix(xs) = p1.prefix(xs) + p2.prefix(xs)
complete(xs) = p1.complete(xs) + p2.complete(xs)
logw_next(x | xs) = p1.logw_next(x | xs) + p2.logw_next(x | xs)

The new potential's vocabulary is the intersection of the two potentials' vocabularies.

This class inherits all methods from Potential; see there for method documentation.

Attributes:

  • p1 (Potential): The first potential instance.
  • p2 (Potential): The second potential instance.
  • token_type (TokenType): The type of tokens that this product potential operates on.
  • vocab (list): The common vocabulary shared between the two potentials.

Warning

Be careful when taking products of potentials with minimal vocabulary overlap. The resulting potential will only operate on tokens present in both vocabularies.
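
To make the index bookkeeping concrete, here is a plain-Python sketch of what `Product.logw_next` does with the two weight vectors: restrict to the shared tokens, look each one up in both vocabularies, and add the log weights. It also mirrors the overlap warning from `_check_vocab_overlap`. The function name is hypothetical, EOS is omitted for brevity, and unlike the library (which builds the common vocabulary from a set intersection, so its order is arbitrary) the sketch keeps `vocab1`'s order.

```python
import warnings


def product_next_logws(vocab1, logw1, vocab2, logw2, threshold=0.1):
    """Hypothetical sketch of combining two next-token log-weight vectors.

    vocab1/vocab2 list each potential's tokens; logw1/logw2 are aligned with them.
    Returns (common_vocab, summed_logws), with common_vocab in vocab1's order.
    """
    lookup1 = {tok: i for i, tok in enumerate(vocab1)}
    lookup2 = {tok: i for i, tok in enumerate(vocab2)}
    common = [tok for tok in vocab1 if tok in lookup2]
    if not common:
        raise ValueError("Potentials in product must share a common vocabulary")

    # Same idea as _check_vocab_overlap: warn when the shared vocabulary is small.
    for name, vocab in (("p1", vocab1), ("p2", vocab2)):
        if len(common) / len(vocab) < threshold:
            warnings.warn(
                f"Common vocabulary is less than {threshold:.0%} of {name}'s vocabulary",
                RuntimeWarning,
            )

    # Multiplying weights is adding log weights, token by token.
    summed = [logw1[lookup1[tok]] + logw2[lookup2[tok]] for tok in common]
    return common, summed
```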

Source code in genlm/control/potential/product.py
class Product(Potential):
    """
    Combine two potential instances via element-wise multiplication (sum in log space).

    This class creates a new potential that is the element-wise product of two potentials:
    ```
    prefix(xs) = p1.prefix(xs) + p2.prefix(xs)
    complete(xs) = p1.complete(xs) + p2.complete(xs)
    logw_next(x | xs) = p1.logw_next(x | xs) + p2.logw_next(x | xs)
    ```

    The new potential's vocabulary is the intersection of the two potentials' vocabularies.

    This class inherits all methods from [`Potential`][genlm.control.potential.base.Potential],
    see there for method documentation.

    Attributes:
        p1 (Potential): The first potential instance.
        p2 (Potential): The second potential instance.
        token_type (TokenType): The type of tokens that this product potential operates on.
        vocab (list): The common vocabulary shared between the two potentials.

    Warning:
        Be careful when taking products of potentials with minimal vocabulary overlap.
        The resulting potential will only operate on tokens present in both vocabularies.
    """

    def __init__(self, p1, p2):
        """Initialize a Product potential.

        Args:
            p1 (Potential): First potential
            p2 (Potential): Second potential
        """
        self.p1 = p1
        self.p2 = p2

        if self.p1.token_type == self.p2.token_type:
            token_type = self.p1.token_type
        else:
            raise ValueError(
                "Potentials in product must have the same token type. "
                f"Got {self.p1.token_type} and {self.p2.token_type}."
                + (
                    "\nMaybe you forgot to coerce the potentials to the same token type? See `Coerce`."
                    if (
                        self.p1.token_type.is_iterable_of(self.p2.token_type)
                        or self.p2.token_type.is_iterable_of(self.p1.token_type)
                    )
                    else ""
                )
            )

        if self.p1.vocab == self.p2.vocab:
            self._v1_idxs = ...
            self._v2_idxs = ...
            super().__init__(self.p1.vocab, token_type=token_type)

        else:
            common_vocab = list(set(self.p1.vocab) & set(self.p2.vocab))
            if not common_vocab:
                raise ValueError("Potentials in product must share a common vocabulary")

            self._check_vocab_overlap(common_vocab, self.p1, self.p2, threshold=0.1)

            self._v1_idxs = None
            self._v2_idxs = None

            super().__init__(common_vocab, token_type=token_type)

    def _check_vocab_overlap(self, common_vocab, p1, p2, threshold=0.1):
        for potential, name in [(p1, "p1"), (p2, "p2")]:
            overlap_ratio = len(common_vocab) / len(potential.vocab)
            if overlap_ratio < threshold:
                warnings.warn(
                    f"Common vocabulary ({len(common_vocab)} tokens) is less than {threshold * 100}% "
                    f"of {name}'s ({potential!r}) vocabulary ({len(potential.vocab)} tokens). "
                    "This Product potential only operates on this relatively small subset of tokens.",
                    RuntimeWarning,
                )

    @property
    def v1_idxs(self):
        if self._v1_idxs is None:
            self._v1_idxs = [self.p1.lookup[token] for token in self.vocab_eos]
        return self._v1_idxs

    @property
    def v2_idxs(self):
        if self._v2_idxs is None:
            self._v2_idxs = [self.p2.lookup[token] for token in self.vocab_eos]
        return self._v2_idxs

    async def prefix(self, context):
        w1 = await self.p1.prefix(context)
        if w1 == float("-inf"):
            return float("-inf")
        w2 = await self.p2.prefix(context)
        return w1 + w2

    async def complete(self, context):
        w1 = await self.p1.complete(context)
        if w1 == float("-inf"):
            return float("-inf")
        w2 = await self.p2.complete(context)
        return w1 + w2

    async def batch_complete(self, contexts):
        W1, W2 = await asyncio.gather(
            self.p1.batch_complete(contexts), self.p2.batch_complete(contexts)
        )
        return W1 + W2

    async def batch_prefix(self, contexts):
        W1, W2 = await asyncio.gather(
            self.p1.batch_prefix(contexts), self.p2.batch_prefix(contexts)
        )
        return W1 + W2

    async def logw_next(self, context):
        W1, W2 = await asyncio.gather(
            self.p1.logw_next(context), self.p2.logw_next(context)
        )
        return self.make_lazy_weights(
            W1.weights[self.v1_idxs] + W2.weights[self.v2_idxs]
        )

    async def batch_logw_next(self, contexts):
        Ws1, Ws2 = await asyncio.gather(
            self.p1.batch_logw_next(contexts), self.p2.batch_logw_next(contexts)
        )
        return [
            self.make_lazy_weights(
                Ws1[n].weights[self.v1_idxs] + Ws2[n].weights[self.v2_idxs]
            )
            for n in range(len(contexts))
        ]

    def spawn(self, p1_opts=None, p2_opts=None):
        return Product(
            self.p1.spawn(**(p1_opts or {})),
            self.p2.spawn(**(p2_opts or {})),
        )

    def __repr__(self):
        return f"Product({self.p1!r}, {self.p2!r})"

__init__(p1, p2)

Initialize a Product potential.

Parameters:

  • p1 (Potential, required): First potential.
  • p2 (Potential, required): Second potential.

Coerced

Bases: Potential

Coerce a potential to operate on another vocabulary.

This class allows a potential to be adapted to work with a different set of tokens, defined by a target vocabulary and a coercion function.

This class inherits all methods from Potential. Each method delegates to the corresponding method of the underlying potential, but first maps any input token sequences from the target vocabulary to the original potential's vocabulary using the coercion function.

Formally, if $f$ is the coercion function, then for any sequence $x_1, \ldots, x_n$ of tokens from the target vocabulary,

$$ \textsf{Coerced.prefix}(x_1, \ldots, x_n) = \textsf{Coerced.potential.prefix}(f(x_1, \ldots, x_n)) $$

$$ \textsf{Coerced.complete}(x_1, \ldots, x_n) = \textsf{Coerced.potential.complete}(f(x_1, \ldots, x_n)) $$

Attributes:

  • potential (Potential): The original potential instance that is being coerced.
  • f (callable): A function that maps sequences of tokens from the target vocabulary to sequences of tokens from the original potential's vocabulary.

Note

The coerced potential's vocabulary will by default be pruned to only include tokens that can be mapped to the original potential's vocabulary via the coercion function (i.e. set(f([x])) <= set(potential.vocab)). If no such tokens are found, a ValueError is raised. This behavior can be overridden by setting prune=False, in which case the coerced potential's vocabulary will include all tokens from the target vocabulary.
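
The default pruning rule can be sketched in a few lines. The sketch assumes a base potential whose vocabulary is individual byte values (ints), so that the `bytes` produced by `b"".join` yield directly comparable elements under `set(...)`; this corresponds to the non-`Token` branch of `Coerced.__init__` in the source below. `prune_target_vocab` is a hypothetical helper.

```python
def prune_target_vocab(target_vocab, base_vocab, f):
    """Hypothetical sketch of Coerced's default pruning (prune=True).

    Keeps target token x only if every item of f([x]) is in the base vocabulary.
    """
    base_items = set(base_vocab)
    kept = [tok for tok in target_vocab if set(f([tok])) <= base_items]
    if not kept:
        raise ValueError("No valid tokens found in target vocabulary")
    return kept
```

For example, with `base_vocab = list(b"abc")` (the byte values 97, 98, 99) and `f = b"".join`, the target tokens `b"ab"` and `b"cab"` survive while `b"aZ"` is dropped, because byte 90 (`Z`) is not in the base vocabulary.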

Source code in genlm/control/potential/coerce.py
class Coerced(Potential):
    """
    Coerce a potential to operate on another vocabulary.

    This class allows a potential to be adapted to work with a different set of tokens,
    defined by a target vocabulary and a coercion function.

    This class inherits all methods from [`Potential`][genlm.control.potential.base.Potential].
    Each method delegates to the corresponding method of the underlying potential, but first
    maps any input token sequences from the target vocabulary to the original potential's vocabulary
    using the coercion function.

    Formally, if $f$ is the coercion function, then for any sequence $x_1, \\ldots, x_n$ of tokens from the target vocabulary,
    $$
    \\textsf{Coerced.prefix}(x_1, \\ldots, x_n) = \\textsf{Coerced.potential.prefix}(f(x_1, \\ldots, x_n))
    $$

    $$
    \\textsf{Coerced.complete}(x_1, \\ldots, x_n) = \\textsf{Coerced.potential.complete}(f(x_1, \\ldots, x_n))
    $$

    Attributes:
        potential (Potential): The original potential instance that is being coerced.
        f (callable): A function that maps sequences of tokens from the target vocabulary to sequences of tokens from
            the original potential's vocabulary.

    Note:
        The coerced potential's vocabulary will by default be pruned to only include tokens that can be mapped to the original potential's vocabulary
        via the coercion function (i.e. `set(f([x])) <= set(potential.vocab)`). If no such tokens are found, a `ValueError` is raised.
        This behavior can be overridden by setting `prune=False`, in which case the coerced potential's vocabulary will include all tokens from the target vocabulary.
    """

    def __init__(self, potential, target_vocab, f, prune=True):
        """
        Initialize a Coerced potential.

        Args:
            potential (Potential): The original potential instance that is being coerced.
            target_vocab (list): The target vocabulary that the potential will operate on.
                Each element of `target_vocab` must be hashable.
            f (callable): A function that maps iterables of tokens from the target vocabulary
                to sequences of tokens from the original potential's vocabulary.
            prune (bool): Whether to prune the coerced potential's vocabulary to only include tokens that can be mapped to the original potential's vocabulary.
                If `False`, the coerced potential's vocabulary will include all tokens from the target vocabulary.

        Raises:
            ValueError: If no valid tokens are found in the target vocabulary that can be mapped to the original potential's vocabulary.
        """
        self.potential = potential
        self.f = f

        if prune:
            # When vocab contains Token objects (bytes subclass), the coercion
            # function f (typically b"".join) produces bytes. set(bytes) yields
            # int byte values, so we need potential_items to also be int byte
            # values for the subset check to work.
            if potential.vocab and isinstance(potential.vocab[0], Token):
                potential_items = set(
                    byte_val for tok in potential.vocab for byte_val in tok.byte_string
                )
            else:
                potential_items = set(potential.vocab)

            tokens = []
            for target_token in target_vocab:
                base_token = f([target_token])
                if set(base_token) <= potential_items:
                    tokens.append(target_token)
        else:
            tokens = target_vocab

        if not tokens:
            raise ValueError("No valid tokens found in target vocabulary")

        super().__init__(tokens)

    def _batch_f(self, contexts):
        return [self.f(context) for context in contexts]

    async def complete(self, context):
        return await self.potential.complete(context=self.f(context))

    async def prefix(self, context):
        return await self.potential.prefix(context=self.f(context))

    async def logw_next(self, context):
        Ws = self.alloc_logws()
        ctx = self.f(context)
        ctx_w = await self.potential.prefix(ctx)
        Ws[-1] = await self.potential.complete(ctx) - ctx_w
        exts = [self.f(chain(context, [x])) for x in self.vocab]  # slow!!
        Ws[:-1] = await self.potential.batch_prefix(exts) - ctx_w
        return self.make_lazy_weights(Ws)

    async def batch_complete(self, contexts):
        return await self.potential.batch_complete(contexts=self._batch_f(contexts))

    async def batch_prefix(self, contexts):
        return await self.potential.batch_prefix(contexts=self._batch_f(contexts))

    async def batch_logw_next(self, contexts):
        return await asyncio.gather(*[self.logw_next(context) for context in contexts])

    def __repr__(self):
        return f"{self.__class__.__name__}({self.potential!r})"
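
`Coerced.logw_next` above derives next-token weights from the base potential's prefix and complete weights on coerced contexts. The synchronous sketch below restates that computation with plain callables so it can be checked by hand. `coerced_logw_next`, the `EOS` string, and the toy base functions are hypothetical, and the sketch assumes the coerced context has a finite prefix weight (otherwise `-inf - -inf` gives NaN).

```python
EOS = "<EOS>"  # hypothetical stand-in for the end-of-sequence token


def coerced_logw_next(context, vocab, f, base_prefix, base_complete):
    """Hypothetical, synchronous sketch of Coerced.logw_next.

    Each next token x gets prefix(f(context + [x])) - prefix(f(context));
    EOS gets complete(f(context)) - prefix(f(context)).
    """
    ctx = f(context)
    ctx_w = base_prefix(ctx)
    # Re-applies f to the whole extended context for every token (the "slow" step).
    weights = {x: base_prefix(f([*context, x])) - ctx_w for x in vocab}
    weights[EOS] = base_complete(ctx) - ctx_w
    return weights
```

With a base potential that accepts only prefixes of `b"hello"`, the context `[b"he", b"ll"]` leaves `b"o"` as the only next token with finite weight, and EOS gets `-inf` because `b"hell"` is not a complete string.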

__init__(potential, target_vocab, f, prune=True)

Initialize a Coerced potential.

Parameters:

  • potential (Potential, required): The original potential instance that is being coerced.
  • target_vocab (list, required): The target vocabulary that the potential will operate on. Each element of target_vocab must be hashable.
  • f (callable, required): A function that maps iterables of tokens from the target vocabulary to sequences of tokens from the original potential's vocabulary.
  • prune (bool, default True): Whether to prune the coerced potential's vocabulary to only include tokens that can be mapped to the original potential's vocabulary. If False, the coerced potential's vocabulary will include all tokens from the target vocabulary.

Raises:

  • ValueError: If no valid tokens are found in the target vocabulary that can be mapped to the original potential's vocabulary.

HarmonyPotential

Bases: Potential

A potential that applies a base constraint to specific channels of the Harmony chat format.

The Harmony chat format structures LLM output into named channels (analysis, final, commentary). This potential extracts the content of specified channels and evaluates them under a base potential, leaving unconstrained channels free.

Attributes:

  • base_potential (Potential): The potential applied to constrained channel contents.
  • harmony_chat (HarmonyChat): Parser for the Harmony chat format.
  • constrained_channels (list[str]): Channels to which the base potential is applied.
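
The aggregation rule used by `complete` and `prefix` can be sketched independently of the Harmony parser. The sketch assumes the parser yields a dict mapping each channel name to either `None` or `{"content": ..., "is_prefix": bool}`, the shape the source below reads. `channel_prefix_weight` is a hypothetical helper, and string contents stand in for the byte tokens the real potential receives.

```python
import asyncio


async def channel_prefix_weight(channels, constrained, base_prefix, base_complete):
    """Hypothetical sketch: sum base-potential log weights over constrained channels.

    Open channels (is_prefix=True) are scored as prefixes; closed ones as complete strings.
    Unconstrained or absent (None) channels contribute nothing.
    """
    coros = []
    for name, info in channels.items():
        if info is None or name not in constrained:
            continue
        score = base_prefix if info["is_prefix"] else base_complete
        coros.append(score(info["content"]))
    if not coros:
        return 0.0
    return sum(await asyncio.gather(*coros))
```

Summing log weights across channels treats the constraint as a product over channels, so a single violating channel drives the total to `-inf`.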

Source code in genlm/control/potential/harmony.py
class HarmonyPotential(Potential):
    """A potential that applies a base constraint to specific channels of the Harmony chat format.

    The Harmony chat format structures LLM output into named channels (analysis, final, commentary).
    This potential extracts the content of specified channels and evaluates them under a base
    potential, leaving unconstrained channels free.

    Attributes:
        base_potential (Potential): The potential applied to constrained channel contents.
        harmony_chat (HarmonyChat): Parser for the Harmony chat format.
        constrained_channels (list[str]): Channels to which the base potential is applied.
    """

    def __init__(
        self,
        base_potential: Potential,
        llm_tokenizer: Any,
        constrained_channels: list[str],
    ) -> None:
        """Initialize the HarmonyPotential.

        Args:
            base_potential (Potential): A base potential applied to the constrained channels.
            llm_tokenizer: A tokenizer that supports the harmony chat format.
            constrained_channels (list[str]): A non-empty list of channels to constrain.
                Each element must be one of ``"analysis"``, ``"final"``, or ``"commentary"``.

        Raises:
            ValueError: If ``constrained_channels`` is empty or contains invalid channel names.
            AssertionError: If the base potential's vocabulary is not a subset of the
                harmony potential's vocabulary.
        """
        if not constrained_channels:
            raise ValueError("constrained_channels must be a non-empty list.")
        invalid = set(constrained_channels) - VALID_CHANNELS
        if invalid:
            raise ValueError(
                f"Invalid channel names: {invalid}. Must be one of {VALID_CHANNELS}."
            )

        self.base_potential = base_potential
        self.harmony_chat = HarmonyChat(llm_tokenizer)
        self.constrained_channels = constrained_channels

        super().__init__(self.harmony_chat.potential_vocab)

        assert set(base_potential.vocab) <= set(
            self.vocab
        ), "The base potential's vocabulary must be a subset of the harmony potential's vocabulary."

    async def complete(self, context: list[bytes]) -> float:
        """Compute the log weight of the constrained channels as complete sequences.

        Args:
            context (list[bytes]): A list of byte tokens.

        Returns:
            (float): The sum (in log space) of the base potential's complete weight for each
                constrained channel. Returns 0 if no constrained channel is present.
        """
        channels = self.harmony_chat.extract_harmony_channels_from_tokens(context)

        coroutines = [
            self.base_potential.complete(channels[key]["content"])
            for key in channels
            if channels[key] is not None and key in self.constrained_channels
        ]
        if not coroutines:
            return 0.0
        results = await asyncio.gather(*coroutines)
        return sum(results)

    async def prefix(self, context: list[bytes]) -> float:
        """Compute the log weight of the constrained channels as a prefix.

        Each constrained channel is evaluated with the base potential: completed
        channels use ``complete``, while the currently open channel uses ``prefix``.

        Args:
            context (list[bytes]): A list of byte tokens.

        Returns:
            (float): The sum (in log space) of the base potential's weight for each
                constrained channel. Returns 0 if no constrained channel is present.
        """
        channels = self.harmony_chat.extract_harmony_channels_from_tokens(context)
        coroutines = []
        for key in channels:
            if channels[key] is not None and key in self.constrained_channels:
                if channels[key]["is_prefix"]:
                    coroutines.append(
                        self.base_potential.prefix(channels[key]["content"])
                    )
                else:
                    # Completed channels also contribute to the prefix weight.
                    coroutines.append(
                        self.base_potential.complete(channels[key]["content"])
                    )
        if not coroutines:
            return 0.0
        results = await asyncio.gather(*coroutines)
        return sum(results)

    async def logw_next(self, context: list[bytes]) -> LazyWeights:
        """Compute next-token log weights for each possible next token, including EOS.

        Args:
            context (list[bytes]): A list of byte tokens.

        Returns:
            (LazyWeights): Weights of each token in the vocabulary and EOS.

        Note:
            In the harmony chat format, the analysis and commentary channels are
            closed by the ``<|end|>`` token, while the final channel is closed by
            ``<|return|>`` (which also closes the chat and halts generation).

            The base potential uses the built-in EOS symbol to represent "the
            constrained string ends here." We need to remap this to the token
            the LLM actually emits to close the channel:

            - **analysis/commentary**: Move the base potential's EOS weight to the
              ``<|end|>`` token and set EOS to -inf, since generation must not halt
              mid-turn.
            - **final**: No remapping needed, because ``PromptedLLM`` already moves
              ``<|return|>`` probability to EOS, so the base potential and the LLM
              are already aligned.
        """

        channels = self.harmony_chat.extract_harmony_channels_from_tokens(context)

        next_token_weights = self.make_lazy_weights(np.zeros(len(self.vocab_eos)))
        incomplete_channels = {
            key
            for key in channels
            if channels[key] is not None and channels[key]["is_prefix"]
        }
        assert (
            len(incomplete_channels) <= 1
        ), "At most one channel can have the 'is_prefix' flag set to true."

        if len(incomplete_channels) == 0:
            return next_token_weights  # pragma: no cover

        key = incomplete_channels.pop()
        if key is not None and key in self.constrained_channels:
            if await self.base_potential.prefix(channels[key]["content"]) == float(
                "-inf"
            ):
                raise ValueError(  # pragma: no cover
                    f"Context {channels[key]['content']!r} has weight zero under `prefix`."
                )

            next_token_weights.weights += (
                await self.base_potential.logw_next(channels[key]["content"])
            ).weights

            if key == "analysis" or key == "commentary":
                # The base potential's EOS weight represents "string is complete."
                # Remap it to <|end|> (which the LLM emits to close these channels)
                # and set EOS to -inf to prevent the LLM from halting mid-turn.
                eos_weight = next_token_weights.weights[-1]
                idx = next_token_weights.encode[self.harmony_chat.end_token]
                next_token_weights.weights[idx] = eos_weight
                next_token_weights.weights[-1] = float("-inf")

            # For the final channel, no remapping is needed: PromptedLLM already
            # maps <|return|> to EOS, so the base potential's EOS weight is
            # already aligned with the LLM's halting token.

        return next_token_weights

__init__(base_potential, llm_tokenizer, constrained_channels)

Initialize the HarmonyPotential.

Parameters:

  • base_potential (Potential, required): A base potential applied to the constrained channels.
  • llm_tokenizer (Any, required): A tokenizer that supports the harmony chat format.
  • constrained_channels (list[str], required): A non-empty list of channels to constrain. Each element must be one of "analysis", "final", or "commentary".

Raises:

  • ValueError: If constrained_channels is empty or contains invalid channel names.
  • AssertionError: If the base potential's vocabulary is not a subset of the harmony potential's vocabulary.

Source code in genlm/control/potential/harmony.py
def __init__(
    self,
    base_potential: Potential,
    llm_tokenizer: Any,
    constrained_channels: list[str],
) -> None:
    """Initialize the HarmonyPotential.

    Args:
        base_potential (Potential): A base potential applied to the constrained channels.
        llm_tokenizer: A tokenizer that supports the harmony chat format.
        constrained_channels (list[str]): A non-empty list of channels to constrain.
            Each element must be one of ``"analysis"``, ``"final"``, or ``"commentary"``.

    Raises:
        ValueError: If ``constrained_channels`` is empty or contains invalid channel names.
        AssertionError: If the base potential's vocabulary is not a subset of the
            harmony potential's vocabulary.
    """
    if not constrained_channels:
        raise ValueError("constrained_channels must be a non-empty list.")
    invalid = set(constrained_channels) - VALID_CHANNELS
    if invalid:
        raise ValueError(
            f"Invalid channel names: {invalid}. Must be one of {VALID_CHANNELS}."
        )

    self.base_potential = base_potential
    self.harmony_chat = HarmonyChat(llm_tokenizer)
    self.constrained_channels = constrained_channels

    super().__init__(self.harmony_chat.potential_vocab)

    assert set(base_potential.vocab) <= set(
        self.vocab
    ), "The base potential's vocabulary must be a subset of the harmony potential's vocabulary."
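
The argument checks at the top of __init__ can be exercised in isolation. The sketch below is a standalone illustration, not the library's code: VALID_CHANNELS is not shown in this excerpt and is assumed here to be the set of three channel names listed in the docstring, and both helper functions are hypothetical.

```python
# Assumption: VALID_CHANNELS is defined elsewhere in the module; here it is taken
# to be the three channel names listed in the docstring.
VALID_CHANNELS = {"analysis", "final", "commentary"}


def check_constrained_channels(constrained_channels: list[str]) -> None:
    """Hypothetical helper mirroring the argument checks at the top of __init__."""
    if not constrained_channels:
        raise ValueError("constrained_channels must be a non-empty list.")
    invalid = set(constrained_channels) - VALID_CHANNELS
    if invalid:
        raise ValueError(
            f"Invalid channel names: {invalid}. Must be one of {VALID_CHANNELS}."
        )


def base_vocab_fits(base_vocab: list[bytes], harmony_vocab: list[bytes]) -> bool:
    """Hypothetical mirror of the final assertion: every base token must be in the harmony vocab."""
    return set(base_vocab) <= set(harmony_vocab)


check_constrained_channels(["analysis", "final"])  # passes silently
try:
    check_constrained_channels(["reasoning"])
except ValueError as err:
    print(err)  # names the invalid channel
```

Rejecting bad channel names eagerly means a typo such as "reasoning" fails at construction time instead of silently leaving every channel unconstrained.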

complete(context) async

Compute the log weight of the constrained channels as complete sequences.

Parameters:

  • context (list[bytes], required): A list of byte tokens.

Returns:

  • float: The sum (in log space) of the base potential's complete weight for each constrained channel. Returns 0.0 (the log of weight 1, i.e. no constraint) if no constrained channel is present.

Source code in genlm/control/potential/harmony.py
async def complete(self, context: list[bytes]) -> float:
    """Compute the log weight of the constrained channels as complete sequences.

    Args:
        context (list[bytes]): A list of byte tokens.

    Returns:
        (float): The sum (in log space) of the base potential's complete weight for each
            constrained channel. Returns 0 if no constrained channel is present.
    """
    channels = self.harmony_chat.extract_harmony_channels_from_tokens(context)

    coroutines = [
        self.base_potential.complete(channels[key]["content"])
        for key in channels
        if channels[key] is not None and key in self.constrained_channels
    ]
    if not coroutines:
        return 0.0
    results = await asyncio.gather(*coroutines)
    return sum(results)
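
The way complete combines channels can be reproduced with a toy base potential. In this sketch DigitsOnly and harmony_complete are hypothetical stand-ins, and the channels dict is written by hand in the shape returned by extract_harmony_channels_from_tokens. Because the per-channel results are log weights, summing them multiplies the underlying weights, so a single rejected channel drives the total to -inf.

```python
import asyncio


class DigitsOnly:
    """Toy base potential (hypothetical): log weight 0.0 if every token is an
    ASCII digit, -inf otherwise."""

    async def complete(self, context: list[bytes]) -> float:
        return 0.0 if all(tok.isdigit() for tok in context) else float("-inf")


async def harmony_complete(base, channels: dict, constrained_channels: list[str]) -> float:
    # One coroutine per constrained channel that is present in the context.
    coroutines = [
        base.complete(channel["content"])
        for name, channel in channels.items()
        if channel is not None and name in constrained_channels
    ]
    if not coroutines:
        return 0.0  # log(1): nothing to constrain
    # Summing log weights multiplies the underlying weights across channels.
    return sum(await asyncio.gather(*coroutines))


example_channels = {
    "analysis": {"content": [b"let", b" me", b" think"], "is_prefix": False},
    "final": {"content": [b"4", b"2"], "is_prefix": False},
    "commentary": None,
}

print(asyncio.run(harmony_complete(DigitsOnly(), example_channels, ["final"])))  # 0.0
print(asyncio.run(harmony_complete(DigitsOnly(), example_channels, ["analysis", "final"])))  # -inf
```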

prefix(context) async

Compute the log weight of the constrained channels as a prefix.

Each constrained channel is evaluated with the base potential: completed channels use complete, while the currently open channel uses prefix.

Parameters:

  • context (list[bytes], required): A list of byte tokens.

Returns:

  • float: The sum (in log space) of the base potential's weight for each constrained channel. Returns 0.0 (the log of weight 1, i.e. no constraint) if no constrained channel is present.

Source code in genlm/control/potential/harmony.py
async def prefix(self, context: list[bytes]) -> float:
    """Compute the log weight of the constrained channels as a prefix.

    Each constrained channel is evaluated with the base potential: completed
    channels use ``complete``, while the currently open channel uses ``prefix``.

    Args:
        context (list[bytes]): A list of byte tokens.

    Returns:
        (float): The sum (in log space) of the base potential's weight for each
            constrained channel. Returns 0 if no constrained channel is present.
    """
    channels = self.harmony_chat.extract_harmony_channels_from_tokens(context)
    coroutines = []
    for key in channels:
        if channels[key] is not None and key in self.constrained_channels:
            if channels[key]["is_prefix"]:
                coroutines.append(
                    self.base_potential.prefix(channels[key]["content"])
                )
            else:
                # Completed channels also contribute to the prefix weight.
                coroutines.append(
                    self.base_potential.complete(channels[key]["content"])
                )
    if not coroutines:
        return 0.0
    results = await asyncio.gather(*coroutines)
    return sum(results)
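
The distinction between open and closed channels matters whenever the base potential's prefix and complete disagree. In the sketch below, the hypothetical toy potential ExactlyYes accepts only b"yes". The same content [b"y", b"e"] is viable while its channel is still open, but is rejected once the channel has been closed by <|end|>. harmony_prefix is an illustrative re-implementation, not the library function.

```python
import asyncio


class ExactlyYes:
    """Toy base potential (hypothetical): accepts only the byte string b"yes"."""

    async def complete(self, context: list[bytes]) -> float:
        return 0.0 if b"".join(context) == b"yes" else float("-inf")

    async def prefix(self, context: list[bytes]) -> float:
        return 0.0 if b"yes".startswith(b"".join(context)) else float("-inf")


async def harmony_prefix(base, channels: dict, constrained_channels: list[str]) -> float:
    coroutines = []
    for name, channel in channels.items():
        if channel is None or name not in constrained_channels:
            continue
        if channel["is_prefix"]:
            coroutines.append(base.prefix(channel["content"]))  # channel still open
        else:
            coroutines.append(base.complete(channel["content"]))  # closed by <|end|>
    if not coroutines:
        return 0.0
    return sum(await asyncio.gather(*coroutines))


open_analysis = {"analysis": {"content": [b"y", b"e"], "is_prefix": True}, "final": None, "commentary": None}
closed_analysis = {"analysis": {"content": [b"y", b"e"], "is_prefix": False}, "final": None, "commentary": None}

print(asyncio.run(harmony_prefix(ExactlyYes(), open_analysis, ["analysis"])))  # 0.0
print(asyncio.run(harmony_prefix(ExactlyYes(), closed_analysis, ["analysis"])))  # -inf
```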

logw_next(context) async

Compute next-token log weights for each possible next token, including EOS.

Parameters:

  • context (list[bytes], required): A list of byte tokens.

Returns:

  • LazyWeights: Next-token log weights for each token in the vocabulary and for EOS.

Note:

In the harmony chat format, the analysis and commentary channels are closed by the <|end|> token, while the final channel is closed by <|return|> (which also closes the chat and halts generation).

The base potential uses the built-in EOS symbol to represent "the constrained string ends here." We need to remap this to the token the LLM actually emits to close the channel:

  • analysis/commentary: Assign the base potential's EOS weight to the <|end|> token (overwriting its previous weight) and set EOS to -inf, since generation must not halt mid-turn.
  • final: No remapping needed, because PromptedLLM already moves <|return|> probability to EOS, so the base potential and the LLM are already aligned.

Source code in genlm/control/potential/harmony.py
async def logw_next(self, context: list[bytes]) -> LazyWeights:
    """Compute next-token log weights for each possible next token, including EOS.

    Args:
        context (list[bytes]): A list of byte tokens.

    Returns:
        (LazyWeights): Weights of each token in the vocabulary and EOS.

    Note:
        In the harmony chat format, the analysis and commentary channels are
        closed by the ``<|end|>`` token, while the final channel is closed by
        ``<|return|>`` (which also closes the chat and halts generation).

        The base potential uses the built-in EOS symbol to represent "the
        constrained string ends here." We need to remap this to the token
        the LLM actually emits to close the channel:

        - **analysis/commentary**: Move the base potential's EOS weight to the
          ``<|end|>`` token and set EOS to -inf, since generation must not halt
          mid-turn.
        - **final**: No remapping needed, because ``PromptedLLM`` already moves
          ``<|return|>`` probability to EOS, so the base potential and the LLM
          are already aligned.
    """

    channels = self.harmony_chat.extract_harmony_channels_from_tokens(context)

    next_token_weights = self.make_lazy_weights(np.zeros(len(self.vocab_eos)))
    incomplete_channels = {
        key
        for key in channels
        if channels[key] is not None and channels[key]["is_prefix"]
    }
    assert (
        len(incomplete_channels) <= 1
    ), "At most one channel can have the 'is_prefix' flag set to true."

    if len(incomplete_channels) == 0:
        return next_token_weights  # pragma: no cover

    key = incomplete_channels.pop()
    if key is not None and key in self.constrained_channels:
        if await self.base_potential.prefix(channels[key]["content"]) == float(
            "-inf"
        ):
            raise ValueError(  # pragma: no cover
                f"Context {channels[key]['content']!r} has weight zero under `prefix`."
            )

        next_token_weights.weights += (
            await self.base_potential.logw_next(channels[key]["content"])
        ).weights

        if key == "analysis" or key == "commentary":
            # The base potential's EOS weight represents "string is complete."
            # Remap it to <|end|> (which the LLM emits to close these channels)
            # and set EOS to -inf to prevent the LLM from halting mid-turn.
            eos_weight = next_token_weights.weights[-1]
            idx = next_token_weights.encode[self.harmony_chat.end_token]
            next_token_weights.weights[idx] = eos_weight
            next_token_weights.weights[-1] = float("-inf")

        # For the final channel, no remapping is needed: PromptedLLM already
        # maps <|return|> to EOS, so the base potential's EOS weight is
        # already aligned with the LLM's halting token.

    return next_token_weights
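
The EOS remapping described in the note reduces to two array assignments. The sketch below works on a plain NumPy array rather than LazyWeights. It assumes that the array is aligned with vocab_eos, so that the last entry is EOS, and it uses a hypothetical four-entry vocabulary. remap_eos is an illustrative helper, not part of the library.

```python
import numpy as np


def remap_eos(logw: np.ndarray, end_idx: int, channel: str) -> np.ndarray:
    """Sketch of the EOS handling in logw_next (last entry of logw is EOS)."""
    out = logw.copy()
    if channel in ("analysis", "commentary"):
        # "The constrained string may end here" becomes "emit <|end|>" ...
        out[end_idx] = out[-1]
        # ... and halting is forbidden, because the turn is not over yet.
        out[-1] = -np.inf
    # For "final", EOS already stands for <|return|>, so nothing changes.
    return out


# Hypothetical vocab_eos: [b"4", b"2", b"<|end|>", EOS]
logw = np.array([np.log(0.5), np.log(0.2), -np.inf, np.log(0.3)])
print(remap_eos(logw, end_idx=2, channel="analysis"))
print(remap_eos(logw, end_idx=2, channel="final"))
```

As in the library code, the <|end|> entry is overwritten rather than combined, so whatever weight it had before remapping is discarded.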

HarmonyChat

Encodes the structure of the "assistant" field of the Harmony chat format.

Provides methods to extract the "harmony channels" (analysis, final, commentary) from it. Since it operates on the byte representation of tokens, it also provides methods to convert between token IDs and byte representations.

Attributes:

  • tokenizer: The tokenizer used to encode and decode tokens.
  • token_maps (TokenMappings): Mappings between token IDs and byte representations.
  • potential_vocab (list[bytes]): The byte vocabulary used by potentials.
  • end_token (bytes): Byte representation of the <|end|> token.
  • message_token (bytes): Byte representation of the <|message|> token.
  • channel_token (bytes): Byte representation of the <|channel|> token.
  • analysis_tokens (list[bytes]): Byte tokens encoding the "analysis" string (not a reserved token, so it may span several tokens).
  • final_tokens (list[bytes]): Byte tokens encoding the "final" string (may span several tokens).
  • commentary_tokens (list[bytes]): Byte tokens encoding the "commentary" string (may span several tokens).

Source code in genlm/control/potential/harmony.py
class HarmonyChat:
    """Encodes the structure of the "assistant" field of the Harmony chat format.

    Provides methods to extract the "harmony channels" (analysis, final, commentary)
    from it. Since it operates on the byte representation of tokens, it also provides
    methods to convert between token IDs and byte representations.

    Attributes:
        tokenizer: The tokenizer used to encode and decode tokens.
        token_maps (TokenMappings): Mappings between token IDs and byte representations.
        potential_vocab (list[bytes]): The byte vocabulary used by potentials.
        end_token (bytes): Byte representation of the ``<|end|>`` token.
        message_token (bytes): Byte representation of the ``<|message|>`` token.
        channel_token (bytes): Byte representation of the ``<|channel|>`` token.
        analysis_tokens (list[bytes]): Byte representation of the ``"analysis"`` string.
        final_tokens (list[bytes]): Byte representation of the ``"final"`` string.
        commentary_tokens (list[bytes]): Byte representation of the ``"commentary"`` string.
    """

    def __init__(self, tokenizer: Any) -> None:
        """
        Initialize HarmonyChat with a tokenizer.

        Args:
            tokenizer: A tokenizer that supports the harmony chat format.
                The tokenizer must be able to encode the harmony chat tokens
                as single tokens.

        """
        # Check that the tokenizer object has the minimum required methods.
        assert hasattr(tokenizer, "encode"), "Tokenizer is missing the 'encode' method."
        assert hasattr(tokenizer, "decode"), "Tokenizer is missing the 'decode' method."
        assert hasattr(
            tokenizer, "apply_chat_template"
        ), "Tokenizer is missing the 'apply_chat_template' method."

        # Check that the tokenizer supports the special tokens of the harmony chat format
        # (in which case they should all be encoded as single tokens).
        for token in [
            "<|start|>",
            "<|channel|>",
            "<|message|>",
            "<|end|>",
            "<|return|>",
        ]:
            assert len(tokenizer.encode(token)) == 1, (
                f"Token {token!r} is not encoded as a single token. "
                "The tokenizer does not appear to support the harmony chat format."
            )

        self.tokenizer = tokenizer
        _byte_vocab, _ = decode_vocab(
            tokenizer
        )  # Byte representation of each token. Follows the same schema as PromptedLLM.
        _eos_byte_strings = [
            _byte_vocab[
                tokenizer.eos_token_id
            ].byte_string  # for gpt-oss, this is the <|return|> token.
        ]

        self.token_maps = TokenMappings.create(
            decode=_byte_vocab, eos_byte_strings=_eos_byte_strings
        )
        self.potential_vocab = self.token_maps.potential_vocab

        # Store the byte representation of special tokens needed for harmony channel parsing.
        self.end_token = self.decode_tokens(self.tokenizer.encode("<|end|>"))[0]
        self.message_token = self.decode_tokens(self.tokenizer.encode("<|message|>"))[0]
        self.channel_token = self.decode_tokens(self.tokenizer.encode("<|channel|>"))[0]
        self.analysis_tokens = self.decode_tokens(
            self.tokenizer.encode("analysis")
        )  # The following tokens (analysis, commentary, final) are not reserved, and therefore they are not guaranteed to be single tokens.
        self.final_tokens = self.decode_tokens(self.tokenizer.encode("final"))
        self.commentary_tokens = self.decode_tokens(self.tokenizer.encode("commentary"))

    def extract_channel_content(
        self, token_bytes: list[bytes], i: int
    ) -> dict[str, Union[list[bytes], bool]] | None:
        """Extract content between the ``<|message|>`` token at position ``i`` and the next ``<|end|>`` token.

        Args:
            token_bytes (list[bytes]): The full token sequence.
            i (int): Index of the ``<|message|>`` token.

        Returns:
            (dict | None): A dict with keys ``"content"`` (list of byte tokens) and
                ``"is_prefix"`` (bool), or ``None`` if ``i`` is out of bounds.
        """

        if i >= len(token_bytes):
            return None  # pragma: no cover
        i += 1
        if self.end_token in token_bytes[i:]:
            end_position = token_bytes.index(self.end_token, i)
            content = token_bytes[i:end_position]
            is_prefix = False
        else:
            content = token_bytes[i:]
            is_prefix = True

        return {"content": content, "is_prefix": is_prefix}

    def extract_harmony_channels_from_tokens(
        self, token_bytes: list[bytes]
    ) -> dict[str, dict[str, Union[list[bytes], bool]] | None]:
        """Extract analysis, final, and commentary content from token bytes.

        Args:
            token_bytes (list[bytes]): List of byte tokens.

        Returns:
            (dict): A dictionary mapping channel names to their extracted content,
                or ``None`` if the channel is not present.

        Raises:
            AssertionError: If the token bytes do not form a valid harmony chat.
        """

        assert self.validate_harmony_format(
            token_bytes
        ), f"The context is not a valid harmony chat: {token_bytes}"
        results = {"analysis": None, "final": None, "commentary": None}

        for i, token in enumerate(token_bytes[:-2]):
            # The harmony format assumes that the <|channel|> token is immediately followed by the channel type, thus we can stop before the last two tokens.
            # Look for <|channel|> token followed by analysis/final/commentary.
            if token == self.channel_token:
                j = i + 1
                # Check whether the analysis, final or commentary tokens follow the channel opening.
                if (
                    len(token_bytes) >= j + len(self.analysis_tokens)
                    and token_bytes[j : j + len(self.analysis_tokens)]
                    == self.analysis_tokens
                ):
                    results["analysis"] = self.extract_channel_content(
                        token_bytes, j + len(self.analysis_tokens)
                    )
                elif (
                    len(token_bytes) >= j + len(self.final_tokens)
                    and token_bytes[j : j + len(self.final_tokens)] == self.final_tokens
                ):
                    results["final"] = self.extract_channel_content(
                        token_bytes, j + len(self.final_tokens)
                    )
                elif (
                    len(token_bytes) >= j + len(self.commentary_tokens)
                    and token_bytes[j : j + len(self.commentary_tokens)]
                    == self.commentary_tokens
                ):
                    results["commentary"] = self.extract_channel_content(
                        token_bytes, j + len(self.commentary_tokens)
                    )

        return results

    def extract_harmony_channels_from_string(
        self, string: str, add_special_tokens: bool = False
    ) -> dict[str, dict[str, Union[list[bytes], bool]] | None]:
        """Extract analysis, final, and commentary content from a string.

        Uses the tokenizer to map from string to token IDs and from token IDs to token bytes,
        then calls :meth:`extract_harmony_channels_from_tokens`.

        Args:
            string (str): The harmony chat format string to extract channels from.
            add_special_tokens (bool): Whether to add special tokens during encoding.

        Returns:
            (dict): A dictionary mapping channel names to their extracted content
                (same format as :meth:`extract_harmony_channels_from_tokens`).
        """
        token_ids = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        token_bytes = self.decode_tokens(token_ids)
        return self.extract_harmony_channels_from_tokens(token_bytes)

    def encode_tokens(self, tokens):
        """Encode a list of Token objects (or bytes) to token IDs.

        Args:
            tokens (list[Token] | list[bytes]): List of tokens to encode.

        Returns:
            (list[int]): A list of token IDs corresponding to the input tokens.

        Raises:
            ValueError: If any token is not in the vocabulary.
        """
        result = []
        for item in tokens:
            if isinstance(item, Token):
                result.append(item.token_id)
            elif isinstance(item, bytes):
                # Fallback: cached lookup by byte_string (first match)
                if not hasattr(self, "_bytes_to_token_id"):
                    self._bytes_to_token_id = {}
                    for tok in self.token_maps.decode:
                        if tok.byte_string not in self._bytes_to_token_id:
                            self._bytes_to_token_id[tok.byte_string] = tok.token_id
                tid = self._bytes_to_token_id.get(item)
                if tid is None:  # pragma: no cover
                    raise ValueError(f"Token {item!r} not in vocabulary")
                result.append(tid)
            else:  # pragma: no cover
                raise TypeError(f"Expected Token or bytes, got {type(item)}")
        return result

    def decode_tokens(self, ids: list[int]) -> list[bytes]:
        """Decode a list of token IDs to byte tokens.

        Args:
            ids (list[int]): A list of token IDs in the language model's vocabulary.

        Returns:
            (list[bytes]): A list of byte tokens corresponding to the input token IDs.
        """
        assert all(isinstance(x, int) for x in ids), "Token IDs must be integers"
        return [self.token_maps.decode[x] for x in ids]

    def validate_harmony_format(self, context: Union[str, list[bytes]]) -> bool:
        """Validate that the context is a valid harmony chat.

        Validates the "assistant" field of the chat format, which is generated
        by the language model.

        Args:
            context (str | list[bytes]): A string or a list of byte tokens.

        Returns:
            (bool): ``True`` if the context is a valid harmony chat, ``False`` otherwise.
        """
        if (
            isinstance(context, list) and len(context) > 0 and context[-1] == EOS
        ):  # Remove the EOS token if present.
            context = context[:-1]  # pragma: no cover

        if isinstance(context, list) and all(
            isinstance(x, (bytes, Token)) for x in context
        ):
            byte_parts = [Token.as_bytes(x) for x in context]
            context_str = b"".join(byte_parts).decode("utf-8", errors="replace")
        elif isinstance(context, str):  # pragma: no cover
            context_str = context  # pragma: no cover
        else:  # pragma: no cover
            raise ValueError(
                f"Context must be a string or a list of bytes tokens, got {type(context)}"
            )  # pragma: no cover

        pattern = r"""
            ^
            (?:
                (?:<\|start\|>assistant)? # The assistant field is optional, the first one is part of the prompt and not the generated tokens
                (?:\s+to=functions\.\w+)? # Optional Function call field
                <\|channel\|> # We start with the channel specifications
                (analysis|commentary|final) # Choose between the three possible channels
                (?:\s+json)?
                <\|message\|> # The message content begins
                (?:(?!<\|start\|>|<\|message\|>|<\|channel\|>|<\|call\|>|<\|return\|>).)*  # The actual message content can contain everything except the special tokens.
                (?:<\|end\|>|<\|call\|>|<\|return\|>) # The channel is closed by the <|end|>, <|call|>, or <|return|> tokens.
            )*
            $
        """

        match = regex.match(
            pattern, context_str, regex.VERBOSE | regex.DOTALL, partial=True
        )
        if not match:  # If the string does not match, we return False.
            return False  # pragma: no cover

        channel_types = match.captures(1)
        counts = Counter(
            channel_types
        )  # Validate that each channel is used at most once in a turn.
        if any(count > 1 for count in counts.values()):
            return False  # pragma: no cover
        return True
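
The two rules enforced by validate_harmony_format can be approximated with the standard library. First, the text must be a sequence of well-formed messages. Second, each channel may appear at most once per turn. This sketch is a simplification: it uses re instead of the third-party regex module and therefore has no partial=True matching, so it only accepts finished messages, whereas the library method also accepts unfinished prefixes. MESSAGE and validate_complete_turn are hypothetical names.

```python
import re
from collections import Counter

# One assistant message, loosely following the pattern in validate_harmony_format.
MESSAGE = re.compile(
    r"(?:<\|start\|>assistant)?"  # optional header
    r"(?:\s+to=functions\.\w+)?"  # optional function-call recipient
    r"<\|channel\|>(analysis|commentary|final)"
    r"(?:\s+json)?"
    r"<\|message\|>"
    r"(?:(?!<\|(?:start|message|channel|call|return)\|>).)*"  # message body
    r"(?:<\|end\|>|<\|call\|>|<\|return\|>)",  # closing token
    re.DOTALL,
)


def validate_complete_turn(text: str) -> bool:
    """Accept complete messages only, with each channel used at most once."""
    pos, channels = 0, []
    while pos < len(text):
        match = MESSAGE.match(text, pos)
        if match is None:
            return False
        channels.append(match.group(1))
        pos = match.end()
    return all(count <= 1 for count in Counter(channels).values())


print(validate_complete_turn(
    "<|channel|>analysis<|message|>hmm<|end|><|channel|>final<|message|>42<|return|>"
))  # True
print(validate_complete_turn(
    "<|channel|>final<|message|>a<|end|><|channel|>final<|message|>b<|end|>"
))  # False: "final" is used twice
```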

__init__(tokenizer)

Initialize HarmonyChat with a tokenizer.

Parameters:

  • tokenizer (Any, required): A tokenizer that supports the harmony chat format. It must encode each harmony special token (<|start|>, <|channel|>, <|message|>, <|end|>, <|return|>) as a single token.

Source code in genlm/control/potential/harmony.py
def __init__(self, tokenizer: Any) -> None:
    """
    Initialize HarmonyChat with a tokenizer.

    Args:
        tokenizer: A tokenizer that supports the harmony chat format.
            The tokenizer must be able to encode the harmony chat tokens
            as single tokens.

    """
    # Check that the tokenizer object has the minimum required methods.
    assert hasattr(tokenizer, "encode"), "Tokenizer is missing the 'encode' method."
    assert hasattr(tokenizer, "decode"), "Tokenizer is missing the 'decode' method."
    assert hasattr(
        tokenizer, "apply_chat_template"
    ), "Tokenizer is missing the 'apply_chat_template' method."

    # Check that the tokenizer supports the special tokens of the harmony chat format
    # (in which case they should all be encoded as single tokens).
    for token in [
        "<|start|>",
        "<|channel|>",
        "<|message|>",
        "<|end|>",
        "<|return|>",
    ]:
        assert len(tokenizer.encode(token)) == 1, (
            f"Token {token!r} is not encoded as a single token. "
            "The tokenizer does not appear to support the harmony chat format."
        )

    self.tokenizer = tokenizer
    _byte_vocab, _ = decode_vocab(
        tokenizer
    )  # Byte representation of each token. Follows the same schema as PromptedLLM.
    _eos_byte_strings = [
        _byte_vocab[
            tokenizer.eos_token_id
        ].byte_string  # for gpt-oss, this is the <|return|> token.
    ]

    self.token_maps = TokenMappings.create(
        decode=_byte_vocab, eos_byte_strings=_eos_byte_strings
    )
    self.potential_vocab = self.token_maps.potential_vocab

    # Store the byte representation of special tokens needed for harmony channel parsing.
    self.end_token = self.decode_tokens(self.tokenizer.encode("<|end|>"))[0]
    self.message_token = self.decode_tokens(self.tokenizer.encode("<|message|>"))[0]
    self.channel_token = self.decode_tokens(self.tokenizer.encode("<|channel|>"))[0]
    self.analysis_tokens = self.decode_tokens(
        self.tokenizer.encode("analysis")
    )  # The following tokens (analysis, commentary, final) are not reserved, and therefore they are not guaranteed to be single tokens.
    self.final_tokens = self.decode_tokens(self.tokenizer.encode("final"))
    self.commentary_tokens = self.decode_tokens(self.tokenizer.encode("commentary"))
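
The single-token requirement on the special tokens can be checked against any object with an encode method. The sketch below uses a hypothetical ToyTokenizer that keeps known special tokens whole and falls back to one ID per byte for everything else. It also shows why analysis_tokens and the other channel-name attributes are lists: unreserved strings such as "analysis" may span several tokens.

```python
HARMONY_SPECIAL_TOKENS = ["<|start|>", "<|channel|>", "<|message|>", "<|end|>", "<|return|>"]


def non_atomic_special_tokens(encode) -> list[str]:
    """Return the harmony special tokens that `encode` does not map to exactly one ID."""
    return [tok for tok in HARMONY_SPECIAL_TOKENS if len(encode(tok)) != 1]


class ToyTokenizer:
    """Hypothetical tokenizer: known specials get one ID, everything else one ID per byte."""

    def __init__(self, specials: list[str]) -> None:
        self.specials = {s: i for i, s in enumerate(specials)}

    def encode(self, text: str) -> list[int]:
        if text in self.specials:
            return [self.specials[text]]
        return [256 + b for b in text.encode("utf-8")]


full_tok = ToyTokenizer(HARMONY_SPECIAL_TOKENS)
partial_tok = ToyTokenizer(["<|start|>", "<|end|>"])
print(non_atomic_special_tokens(full_tok.encode))  # []
print(non_atomic_special_tokens(partial_tok.encode))  # ['<|channel|>', '<|message|>', '<|return|>']
print(len(full_tok.encode("analysis")))  # 8: channel names are not reserved tokens
```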

extract_channel_content(token_bytes, i)

Extract content between the <|message|> token at position i and the next <|end|> token.

Parameters:

  • token_bytes (list[bytes], required): The full token sequence.
  • i (int, required): Index of the <|message|> token.

Returns:

  • dict | None: A dict with keys "content" (the byte tokens after <|message|>, up to but excluding the next <|end|>) and "is_prefix" (True if no <|end|> follows, i.e. the channel is still open), or None if i is out of bounds.

Source code in genlm/control/potential/harmony.py
def extract_channel_content(
    self, token_bytes: list[bytes], i: int
) -> dict[str, Union[list[bytes], bool]] | None:
    """Extract content between the ``<|message|>`` token at position ``i`` and the next ``<|end|>`` token.

    Args:
        token_bytes (list[bytes]): The full token sequence.
        i (int): Index of the ``<|message|>`` token.

    Returns:
        (dict | None): A dict with keys ``"content"`` (list of byte tokens) and
            ``"is_prefix"`` (bool), or ``None`` if ``i`` is out of bounds.
    """

    if i >= len(token_bytes):
        return None  # pragma: no cover
    i += 1
    if self.end_token in token_bytes[i:]:
        end_position = token_bytes.index(self.end_token, i)
        content = token_bytes[i:end_position]
        is_prefix = False
    else:
        content = token_bytes[i:]
        is_prefix = True

    return {"content": content, "is_prefix": is_prefix}
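
The slicing logic can be tried on a hand-built token list. This standalone sketch assumes that tokens are plain bytes objects. In the library, the tokens come from decode_tokens, and the end token's byte string is read from the tokenizer. The function body mirrors the method above.

```python
from typing import Optional

END = b"<|end|>"  # assumed byte string of the <|end|> token


def extract_channel_content(token_bytes: list[bytes], i: int) -> Optional[dict]:
    """Standalone sketch: i is the index of <|message|>."""
    if i >= len(token_bytes):
        return None
    start = i + 1  # content starts right after <|message|>
    if END in token_bytes[start:]:
        end = token_bytes.index(END, start)
        return {"content": token_bytes[start:end], "is_prefix": False}
    return {"content": token_bytes[start:], "is_prefix": True}  # channel still open


tokens = [b"<|channel|>", b"final", b"<|message|>", b"4", b"2", END]
print(extract_channel_content(tokens, 2))  # closed channel
print(extract_channel_content(tokens[:5], 2))  # same content, still open
```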

extract_harmony_channels_from_tokens(token_bytes)

Extract analysis, final, and commentary content from token bytes.

Parameters:

  • token_bytes (list[bytes], required): List of byte tokens.

Returns:

  • dict: A dictionary with the keys "analysis", "final", and "commentary". Each value is the dict produced by extract_channel_content, or None if the channel is not present.

Raises:

  • AssertionError: If the token bytes do not form a valid harmony chat.

Source code in genlm/control/potential/harmony.py
def extract_harmony_channels_from_tokens(
    self, token_bytes: list[bytes]
) -> dict[str, dict[str, Union[list[bytes], bool]] | None]:
    """Extract analysis, final, and commentary content from token bytes.

    Args:
        token_bytes (list[bytes]): List of byte tokens.

    Returns:
        (dict): A dictionary mapping channel names to their extracted content,
            or ``None`` if the channel is not present.

    Raises:
        AssertionError: If the token bytes do not form a valid harmony chat.
    """

    assert self.validate_harmony_format(
        token_bytes
    ), f"The context is not a valid harmony chat: {token_bytes}"
    results = {"analysis": None, "final": None, "commentary": None}

    for i, token in enumerate(token_bytes[:-2]):
        # The harmony format assumes that the <|channel|> token is immediately followed by the channel type, thus we can stop before the last two tokens.
        # Look for <|channel|> token followed by analysis/final/commentary.
        if token == self.channel_token:
            j = i + 1
            # Check whether the analysis, final or commentary tokens follow the channel opening.
            if (
                len(token_bytes) >= j + len(self.analysis_tokens)
                and token_bytes[j : j + len(self.analysis_tokens)]
                == self.analysis_tokens
            ):
                results["analysis"] = self.extract_channel_content(
                    token_bytes, j + len(self.analysis_tokens)
                )
            elif (
                len(token_bytes) >= j + len(self.final_tokens)
                and token_bytes[j : j + len(self.final_tokens)] == self.final_tokens
            ):
                results["final"] = self.extract_channel_content(
                    token_bytes, j + len(self.final_tokens)
                )
            elif (
                len(token_bytes) >= j + len(self.commentary_tokens)
                and token_bytes[j : j + len(self.commentary_tokens)]
                == self.commentary_tokens
            ):
                results["commentary"] = self.extract_channel_content(
                    token_bytes, j + len(self.commentary_tokens)
                )

    return results
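The scan above looks for the <|channel|> token, checks which channel-name tokens follow it, and hands the remaining tokens to extract_channel_content. The standalone sketch below mirrors that loop. It rests on several assumptions. The special tokens are hypothetical single byte strings rather than the tokenizer-derived attributes (channel_token, analysis_tokens, and so on). The helper channel_content is a simplified, hypothetical stand-in for extract_channel_content, which is only partially shown here: it is assumed to skip to <|message|> and to report is_prefix=True when no closing token has been generated yet.

```python
# Hypothetical special tokens; the real potential derives these from its tokenizer.
CHANNEL = b"<|channel|>"
MESSAGE = b"<|message|>"
CLOSERS = (b"<|end|>", b"<|call|>", b"<|return|>")
CHANNELS = {
    "analysis": [b"analysis"],
    "final": [b"final"],
    "commentary": [b"commentary"],
}


def channel_content(tokens: list, start: int) -> dict:
    """Simplified, hypothetical stand-in for extract_channel_content."""
    # Skip any header tokens (e.g. a " json" marker) up to and including <|message|>.
    i = start
    while i < len(tokens) and tokens[i] != MESSAGE:
        i += 1
    i += 1
    # The content runs until the first closing token. If none has been
    # generated yet, the channel is still open and the content is a prefix.
    for end in range(i, len(tokens)):
        if tokens[end] in CLOSERS:
            return {"content": tokens[i:end], "is_prefix": False}
    return {"content": tokens[i:], "is_prefix": True}


def extract_channels(tokens: list) -> dict:
    """Map each channel name to its extracted content, or None if absent."""
    results: dict = {name: None for name in CHANNELS}
    for i, tok in enumerate(tokens[:-2]):
        if tok != CHANNEL:
            continue
        j = i + 1
        # Compare the tokens right after <|channel|> with each channel marker.
        for name, marker in CHANNELS.items():
            if tokens[j : j + len(marker)] == marker:
                results[name] = channel_content(tokens, j + len(marker))
                break
    return results
```

For a generation whose final message is still being written, the final channel comes back with is_prefix=True, and a channel that never appears stays None.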

extract_harmony_channels_from_string(string, add_special_tokens=False)

Extract analysis, final, and commentary content from a string.

Uses the tokenizer to map the string to token IDs and the token IDs to byte tokens, then calls extract_harmony_channels_from_tokens.

Parameters:

Name Type Description Default
string str

The harmony chat format string to extract channels from.

required
add_special_tokens bool

Whether to add special tokens during encoding.

False

Returns:

Type Description
dict

A dictionary mapping channel names to their extracted content (same format as extract_harmony_channels_from_tokens).

Source code in genlm/control/potential/harmony.py
def extract_harmony_channels_from_string(
    self, string: str, add_special_tokens: bool = False
) -> dict[str, dict[str, Union[list[bytes], bool]] | None]:
    """Extract analysis, final, and commentary content from a string.

    Uses the tokenizer to map from string to token IDs and from token IDs to token bytes,
    then calls :meth:`extract_harmony_channels_from_tokens`.

    Args:
        string (str): The harmony chat format string to extract channels from.
        add_special_tokens (bool): Whether to add special tokens during encoding.

    Returns:
        (dict): A dictionary mapping channel names to their extracted content
            (same format as :meth:`extract_harmony_channels_from_tokens`).
    """
    token_ids = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
    token_bytes = self.decode_tokens(token_ids)
    return self.extract_harmony_channels_from_tokens(token_bytes)

encode_tokens(tokens)

Encode a list of Token objects (or bytes) to token IDs.

Parameters:

Name Type Description Default
tokens list[Token] | list[bytes]

List of tokens to encode.

required

Returns:

Type Description
list[int]

A list of token IDs corresponding to the input tokens.

Raises:

Type Description
ValueError

If any token is not in the vocabulary.

TypeError

If any item is neither a Token nor bytes.

Source code in genlm/control/potential/harmony.py
def encode_tokens(self, tokens):
    """Encode a list of Token objects (or bytes) to token IDs.

    Args:
        tokens (list[Token] | list[bytes]): List of tokens to encode.

    Returns:
        (list[int]): A list of token IDs corresponding to the input tokens.

    Raises:
        ValueError: If any token is not in the vocabulary.
    """
    result = []
    for item in tokens:
        if isinstance(item, Token):
            result.append(item.token_id)
        elif isinstance(item, bytes):
            # Fallback: cached lookup by byte_string (first match)
            if not hasattr(self, "_bytes_to_token_id"):
                self._bytes_to_token_id = {}
                for tok in self.token_maps.decode:
                    if tok.byte_string not in self._bytes_to_token_id:
                        self._bytes_to_token_id[tok.byte_string] = tok.token_id
            tid = self._bytes_to_token_id.get(item)
            if tid is None:  # pragma: no cover
                raise ValueError(f"Token {item!r} not in vocabulary")
            result.append(tid)
        else:  # pragma: no cover
            raise TypeError(f"Expected Token or bytes, got {type(item)}")
    return result
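When a vocabulary contains several tokens with the same byte string, a bytes input cannot say which one it means, so the cached reverse map resolves it to the first matching token ID. A Token object, by contrast, carries its own ID. The sketch below reproduces this first-match-wins behavior. Tok is a hypothetical stand-in for genlm's Token class, and the function names are illustrative, not part of the library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tok:
    """Hypothetical stand-in for genlm's Token: an ID plus its byte string."""

    token_id: int
    byte_string: bytes


def build_reverse_map(decode):
    """Map each byte string to the first token ID that decodes to it."""
    reverse = {}
    for tok in decode:
        # setdefault keeps the first ID seen, so later duplicates never overwrite it.
        reverse.setdefault(tok.byte_string, tok.token_id)
    return reverse


def encode(items, decode):
    reverse = build_reverse_map(decode)
    ids = []
    for item in items:
        if isinstance(item, Tok):
            ids.append(item.token_id)  # exact: the token already carries its ID
        elif isinstance(item, bytes):
            tid = reverse.get(item)
            if tid is None:
                raise ValueError(f"Token {item!r} not in vocabulary")
            ids.append(tid)
        else:
            raise TypeError(f"Expected Tok or bytes, got {type(item)}")
    return ids
```

With decode = [Tok(0, b"a"), Tok(1, b"b"), Tok(2, b"a")], the input b"a" always encodes to 0, while Tok(2, b"a") encodes to 2.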

decode_tokens(ids)

Decode a list of token IDs to byte tokens.

Parameters:

Name Type Description Default
ids list[int]

A list of token IDs in the language model's vocabulary.

required

Returns:

Type Description
list[bytes]

A list of byte tokens corresponding to the input token IDs.

Source code in genlm/control/potential/harmony.py
def decode_tokens(self, ids: list[int]) -> list[bytes]:
    """Decode a list of token IDs to byte tokens.

    Args:
        ids (list[int]): A list of token IDs in the language model's vocabulary.

    Returns:
        (list[bytes]): A list of byte tokens corresponding to the input token IDs.
    """
    assert all(isinstance(x, int) for x in ids), "Token IDs must be integers"
    return [self.token_maps.decode[x] for x in ids]

validate_harmony_format(context)

Validate that the context is a valid harmony chat.

Validates the "assistant" field of the chat format, which is generated by the language model.

Parameters:

Name Type Description Default
context str | list[bytes]

A string or a list of byte tokens.

required

Returns:

Type Description
bool

True if the context is a valid harmony chat, False otherwise.

Source code in genlm/control/potential/harmony.py
def validate_harmony_format(self, context: Union[str, list[bytes]]) -> bool:
    """Validate that the context is a valid harmony chat.

    Validates the "assistant" field of the chat format, which is generated
    by the language model.

    Args:
        context (str | list[bytes]): A string or a list of byte tokens.

    Returns:
        (bool): ``True`` if the context is a valid harmony chat, ``False`` otherwise.
    """
    if (
        isinstance(context, list) and len(context) > 0 and context[-1] == EOS
    ):  # Remove the EOS token if present.
        context = context[:-1]  # pragma: no cover

    if isinstance(context, list) and all(
        isinstance(x, (bytes, Token)) for x in context
    ):
        byte_parts = [Token.as_bytes(x) for x in context]
        context_str = b"".join(byte_parts).decode("utf-8", errors="replace")
    elif isinstance(context, str):  # pragma: no cover
        context_str = context  # pragma: no cover
    else:  # pragma: no cover
        raise ValueError(
            f"Context must be a string or a list of bytes tokens, got {type(context)}"
        )  # pragma: no cover

    pattern = r"""
        ^
        (?:
            (?:<\|start\|>assistant)? # The assistant field is optional, the first one is part of the prompt and not the generated tokens
            (?:\s+to=functions\.\w+)? # Optional Function call field
            <\|channel\|> # We start with the channel specifications
            (analysis|commentary|final) # Choose between the three possible channels
            (?:\s+json)?
            <\|message\|> # The message content begins
            (?:(?!<\|start\|>|<\|message\|>|<\|channel\|>|<\|call\|>|<\|return\|>).)*  # The actual message content can contain everything except the special tokens.
            (?:<\|end\|>|<\|call\|>|<\|return\|>) # The channel is closed by the <|end|>, <|call|>, or <|return|> tokens.
        )*
        $
    """

    match = regex.match(
        pattern, context_str, regex.VERBOSE | regex.DOTALL, partial=True
    )
    if not match:  # If the string does not match, we return False.
        return False  # pragma: no cover

    channel_types = match.captures(1)
    counts = Counter(
        channel_types
    )  # Validate that each channel is used at most once in a turn.
    if any(count > 1 for count in counts.values()):
        return False  # pragma: no cover
    return True
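Besides the grammar check, validate_harmony_format rejects a turn in which any channel name is captured more than once. The sketch below isolates only that counting step. It is a simplification: it uses the standard-library re module and a hypothetical function name, ignores the rest of the grammar, and does not reproduce the partial=True matching of the third-party regex module, which the original relies on so that an unfinished generation can still count as valid.

```python
import re
from collections import Counter

# Only the channel names are inspected; the full harmony grammar is not checked.
CHANNEL_RE = re.compile(r"<\|channel\|>(analysis|commentary|final)")


def channels_used_at_most_once(text: str) -> bool:
    """Return True if no channel name appears more than once in the text."""
    counts = Counter(CHANNEL_RE.findall(text))
    return all(count <= 1 for count in counts.values())
```

A turn with one analysis message and one final message passes, while two final messages in the same turn fail.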

PromptedLLM

Bases: Potential

A potential representing a language model conditioned on a fixed prompt prefix.

PromptedLLMs operate on byte sequences.

Notes on EOS Token Handling:

  • Tokens to treat as end-of-sequence tokens are specified via the eos_byte_strings argument.

  • These tokens are excluded from the potential's vocabulary and therefore do not appear in the vocab attribute.

    This means they cannot appear in any input context to the potential or in the output of logw_next. They can, however, be used in the prompt.

  • The log probability assigned to genlm.control's reserved EOS token is the log of the summed probabilities of all the specified EOS tokens (their logsumexp).
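Combining several EOS tokens means adding their probabilities, which in log space is a logsumexp (the source uses torch.logsumexp in complete and _process_logw_next). Adding the log probabilities directly would instead multiply the probabilities. The worked example below contrasts the two, using hypothetical EOS probabilities of 0.2 and 0.1 and a pure-Python logsumexp in place of the torch call.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Hypothetical next-token log probabilities of two EOS tokens.
eos_logps = [math.log(0.2), math.log(0.1)]

combined = logsumexp(eos_logps)  # log(0.2 + 0.1) = log(0.3): probability of "some EOS"
naive = sum(eos_logps)  # log(0.2 * 0.1) = log(0.02): not an EOS probability
```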

This class wraps an AsyncLM instance.
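As the source below shows, prefix scores a context by the chain rule: it sums log p(context[i] | prompt + context[:i]) over every position. complete adds the combined log probability of the EOS tokens after the full context. The sketch reproduces this arithmetic with a toy model. The fixed distribution, the string tokens, and the function names are all hypothetical. The toy next_token_logprobs ignores its input, whereas a real AsyncLM conditions on prompt + context.

```python
import math

# Toy next-token distribution standing in for AsyncLM.next_token_logprobs.
VOCAB = ["a", "b", "<eos>"]
PROBS = {"a": 0.5, "b": 0.3, "<eos>": 0.2}


def next_token_logprobs(ids):
    # A real model would condition on ids; this toy ignores them.
    return [math.log(PROBS[t]) for t in VOCAB]


def prefix_logp(prompt, context):
    """Chain rule: sum of log p(context[i] | prompt + context[:i])."""
    total = 0.0
    for i, tok in enumerate(context):
        logps = next_token_logprobs(prompt + context[:i])
        total += logps[VOCAB.index(tok)]
    return total


def complete_logp(prompt, context, eos=("<eos>",)):
    """Prefix log probability plus the summed-probability EOS term."""
    logps = next_token_logprobs(prompt + context)
    eos_logp = math.log(sum(math.exp(logps[VOCAB.index(e)]) for e in eos))
    return prefix_logp(prompt, context) + eos_logp
```

Here the context ["a", "b"] has prefix probability 0.5 * 0.3 = 0.15 and complete probability 0.15 * 0.2 = 0.03. An empty context has prefix log probability 0, matching the early return in log_probability.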

Source code in genlm/control/potential/built_in/llm.py
class PromptedLLM(Potential):
    """A potential representing a language model conditioned on a fixed prompt prefix.

    `PromptedLLM`s operate on byte sequences.

    Notes on EOS Token Handling:\n
    - Tokens to treat as end-of-sequence tokens are specified via the `eos_byte_strings` argument.\n
    - These tokens are excluded from the potential's vocabulary and therefore do not appear in the `vocab` attribute.\n
        This means they cannot appear in any input context to the potential or in the output of `logw_next`. They can, however, be used in the prompt.\n
    - The log probability assigned to `genlm.control`'s reserved `EOS` token is the log of the summed probabilities of all the specified EOS tokens (their logsumexp).\n

    This class wraps an `AsyncLM` instance.
    """

    def __init__(
        self,
        llm,
        prompt_ids=None,
        eos_byte_strings=None,
        temperature=1.0,
        token_maps=None,
        **kwargs,
    ):
        """`
        Initializes the PromptedLLM potential.

        Args:
            llm (AsyncLM): The language model to use.
            prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
                Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via `set_prompt_from_str` or `prompt_ids`.
            eos_byte_strings (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
                Defaults to the EOS token of the language model's tokenizer.
            temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
            token_maps (TokenMappings, optional): A precomputed mapping of tokens to token IDs with the potential's vocabulary.
                If provided, `eos_byte_strings` must not be provided. Defaults to None, which constructs a TokenMappings from the language model's byte vocabulary and the EOS tokens.
        """
        eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
        self.model = llm
        self.prompt_ids = prompt_ids or []
        self.temperature = temperature

        if token_maps is not None:
            if eos_byte_strings is not None:
                raise ValueError(
                    "eos_byte_strings must not be provided when token_maps is provided."
                )
            self.token_maps = token_maps
        else:
            byte_vocab = self.model.byte_vocab
            default_eos = byte_vocab[self.model.tokenizer.eos_token_id].byte_string
            self.token_maps = TokenMappings.create(
                decode=byte_vocab,
                eos_byte_strings=eos_byte_strings or [default_eos],
            )

        super().__init__(vocabulary=self.token_maps.potential_vocab)

    @classmethod
    def from_name(
        cls,
        name,
        backend=None,
        eos_byte_strings=None,
        prompt_ids=None,
        temperature=1.0,
        **kwargs,
    ):
        """Create a `PromptedLLM` from a Hugging Face model name.

        Args:
            name (str): Name of the model to load
            backend (str, optional): `AsyncLM` backend to use:\n
                * 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage\n
                * 'hf' for an `AsyncTransformer`; ideal for CPU usage\n
                * 'mlx' for an `AsyncMlxLM`; ideal for Apple silicon usage\n
                * 'mock' for a `MockAsyncLM`; ideal for testing.\n
                Defaults to 'vllm' if CUDA is available, otherwise 'hf'.
            eos_byte_strings (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
                Defaults to the EOS token of the language model's tokenizer.
            prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
                Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via `set_prompt_from_str` or `prompt_ids`.
            temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
            **kwargs (dict): Additional arguments passed to AsyncLM constructor

        Returns:
            (PromptedLLM): An instance of PromptedLLM
        """
        eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
        backend = backend or ("vllm" if torch.cuda.is_available() else "hf")
        model = load_model_by_name(name, backend=backend, **kwargs)
        return cls(
            model, prompt_ids=prompt_ids, eos_byte_strings=eos_byte_strings, temperature=temperature
        )

    @property
    def eos_byte_strings(self):
        return self.token_maps.eos_byte_strings

    @eos_byte_strings.setter
    def eos_byte_strings(self, value):
        raise ValueError(
            "Cannot reset eos_byte_strings after initialization. "
            "Use spawn_new_eos(new_eos_byte_strings) instead."
        )

    @property
    def prompt(self):
        """
        Get the current prompt as Token objects.

        Returns:
            (list[Token]|None): The current prompt as Token objects, or None if no prompt_ids are set.
        """
        if not self.prompt_ids:
            return  # pragma: no cover
        return [self.token_maps.decode[x] for x in self.prompt_ids]

    def set_prompt_from_str(self, prompt_str):
        """Set the fixed prompt from a string.

        Modifies `prompt_ids` to be the token IDs of the input prompt according to the language model's tokenizer.

        Args:
            prompt_str (str): The prompt to set.
        """
        # TODO: Handle race condition where prompt_ids reset concurrently.
        if not isinstance(prompt_str, str):
            raise ValueError(
                f"Prompt must a string got {type(prompt_str)}. "
                f"To set the prompt from a list of token IDs, use prompt_ids."
            )

        if prompt_str.endswith(" "):
            warnings.warn(
                "Prompt ends with whitespace, which may affect tokenization. "
                "Consider removing trailing whitespace.",
                stacklevel=2,
            )

        self.prompt_ids = self.model.tokenizer.encode(prompt_str)

    def _find_token_id_for_bytes(self, byte_string):
        """Find token_id for a byte_string (first match for duplicates).

        Uses a lazily-built cache for O(1) lookup. For duplicate byte strings,
        returns the first token_id encountered in the vocabulary.
        """
        if not hasattr(self, "_bytes_to_token_id"):
            # Build reverse map: bytes → first token_id. Later entries don't
            # overwrite, so the first match wins (consistent with old behavior).
            self._bytes_to_token_id = {}
            for token in self.token_maps.decode:
                if token.byte_string not in self._bytes_to_token_id:
                    self._bytes_to_token_id[token.byte_string] = token.token_id
        return self._bytes_to_token_id.get(byte_string)

    def encode_tokens(self, tokens):
        """Encode a list of Token objects to token IDs.

        Args:
            tokens (list[Token]): List of Token objects

        Returns:
            (list[int]): A list of token IDs corresponding to the input tokens.

        Raises:
            ValueError: If any token is not in the vocabulary.

        Note:
            Passing bytes is deprecated. Use Token objects from llm.tokenize().
        """
        if not tokens:
            return []

        result = []
        warned = False
        for item in tokens:
            if isinstance(item, Token):
                result.append(item.token_id)
            else:
                if not warned:
                    warnings.warn(
                        "Passing bytes to encode_tokens is deprecated. "
                        "Use Token objects for precise control. ",
                        DeprecationWarning,
                        stacklevel=3,
                    )
                    warned = True
                token_id = self._find_token_id_for_bytes(item)
                if token_id is None:
                    raise ValueError(f"Token {item!r} not in vocabulary")
                result.append(token_id)
        return result

    def decode_tokens(self, ids):
        """
        Decode a list of token IDs to Token objects.

        Args:
            ids (list[int]): A list of token IDs in the language model's vocabulary.

        Returns:
            (list[Token]): Token objects corresponding to the input token IDs.
        """
        return [self.token_maps.decode[x] for x in ids]

    def tokenize(self, context_str):
        """Tokenize a string to a list of Token objects.

        Uses the language model's tokenizer to map `context_str` to token IDs,
        then returns the corresponding Token objects.

        Args:
            context_str (str): A string to encode

        Returns:
            (list[Token]): Token objects corresponding to the input string.
        """
        return self.decode_tokens(self.model.tokenizer.encode(context_str))

    async def log_probability(self, context):
        """
        Compute the log probability of `context` given the prompt.

        Args:
            context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

        Returns:
            (float): The log probability of `context`.
        """
        if not context:
            return 0

        context_ids = self.encode_tokens(context)
        return await self._log_probability(context_ids)

    async def _log_probability(self, context_ids):
        prefixes = [self.prompt_ids + context_ids[:i] for i in range(len(context_ids))]
        log_ps = self._maybe_temper(
            await self.model.batch_next_token_logprobs(prefixes)
        )
        target_ids = torch.tensor(context_ids, device=log_ps.device)
        with torch.no_grad():
            token_logprobs = torch.gather(log_ps, 1, target_ids.unsqueeze(1))
            total_logprob = token_logprobs.sum().item()

        return total_logprob

    def _maybe_temper(self, logps):
        if self.temperature == 1:
            return logps
        return torch.log_softmax(logps / self.temperature, dim=-1)

    async def prefix(self, context):
        """
        Compute the log probability of `context` given the prompt.

        Args:
            context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

        Returns:
            (float): The log probability of `context`.
        """
        return await self.log_probability(context)

    async def complete(self, context):
        """
        Compute the log probability of `context` and the eos tokens given the prompt.

        If the model has multiple eos tokens, their probabilities will be summed.

        Args:
            context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

        Returns:
            (float): The log probability of `context` followed by an EOS token.
        """
        context_ids = self.encode_tokens(context)
        logp_context = await self._log_probability(context_ids)
        logp_next = self._maybe_temper(
            await self.model.next_token_logprobs(self.prompt_ids + context_ids)
        )
        logp_eos = torch.logsumexp(logp_next[self.token_maps.eos_idxs], dim=0).item()
        return logp_context + logp_eos

    def _process_logw_next(self, logw_next):
        """Process the log probabilities for the next tokens.

        This function rearranges the log probabilities such that the end-of-sequence (EOS) token's log probability
        is the log of the summed probabilities of the tokens in `self.eos_byte_strings`.

        Args:
            logw_next (torch.tensor): The log probabilities for the next tokens.

        Returns:
            (LazyWeights): Processed log probabilities for the next tokens.
        """
        # This is ugly, but it's useful for all potentials to adhere to the convention
        # of keeping the EOS token at the end of the weights array.

        # Cache eos_idxs_tensor and non_eos_indices on first use
        if (
            not hasattr(self, "_eos_idxs_tensor")
            or not hasattr(self, "_non_eos_indices")
            or self._eos_idxs_tensor.device != logw_next.device
        ):
            self._eos_idxs_tensor = torch.tensor(
                self.token_maps.eos_idxs, device=logw_next.device
            )
            all_indices = torch.arange(
                len(self.token_maps.decode), device=logw_next.device
            )
            self._non_eos_indices = all_indices[
                ~torch.isin(all_indices, self._eos_idxs_tensor)
            ]

        # The model may produce fewer logits than len(token_maps.decode) when
        # the tokenizer has added tokens beyond the model's embedding matrix
        # (e.g. Gemma's <image_soft_token>). Pad with -inf so these tokens
        # are unscorable but still present in the vocabulary.
        # We assert that HF models always produce logits for token indices
        # 0..vocab_size-1, and added tokens are at indices >= vocab_size.
        n_decode = len(self.token_maps.decode)
        n_logits = len(logw_next)
        if n_logits < n_decode:
            # Verify (once) that token IDs in the model's logit range are
            # contiguous 0..n_logits-1, so padding the tail is safe.
            if not hasattr(self, "_logit_padding_verified"):
                for i in range(n_logits):
                    if self.token_maps.decode[i].token_id != i:
                        raise ValueError(
                            f"Token ID / index mismatch at position {i}: "
                            f"decode[{i}].token_id={self.token_maps.decode[i].token_id}. "
                            f"Padding assumes added tokens are at indices >= vocab_size."
                        )
                self._logit_padding_verified = True
            pad = torch.full(
                (n_decode - n_logits,),
                float("-inf"),
                dtype=logw_next.dtype,
                device=logw_next.device,
            )
            logw_next = torch.cat([logw_next, pad])

        logw_next = logw_next[:n_decode]
        logw_next = logw_next.log_softmax(dim=0)
        _logw_next = torch.full(
            (len(self.vocab) + 1,),
            float("-inf"),
            dtype=logw_next.dtype,
            device=logw_next.device,
        )
        _logw_next[: len(self.vocab)] = logw_next[self._non_eos_indices]

        # Special case: if only one EOS idx, just assign directly (avoids cost of logsumexp)
        if self._eos_idxs_tensor.numel() == 1:
            _logw_next[-1] = logw_next[self._eos_idxs_tensor]
        else:
            _logw_next[-1] = torch.logsumexp(logw_next[self._eos_idxs_tensor], dim=0)

        return self.make_lazy_weights(_logw_next.float().cpu().numpy())

    async def logw_next(self, context):
        """Get log probabilities for next tokens given the prompt and `context`.

        Args:
            context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

        Returns:
            (LazyWeights): Log probabilities for next tokens and EOS. Keys are Token objects.
        """
        context_ids = self.encode_tokens(context)
        logw_next = self._maybe_temper(
            await self.model.next_token_logprobs(self.prompt_ids + context_ids)
        )
        return self._process_logw_next(logw_next)

    async def batch_logw_next(self, contexts):
        """Get log probabilities for next tokens given the prompt and `context`, for a batch of contexts.

        Args:
            contexts (list[list[bytes]] | list[list[Token]]): A list of sequences of byte tokens or Token objects.

        Returns:
            (list[LazyWeights]): Log probabilities for next tokens and EOS for each context. Keys are Token objects.
        """
        context_ids_batch = [self.encode_tokens(context) for context in contexts]
        logw_nexts = self._maybe_temper(
            await self.model.batch_next_token_logprobs(
                [self.prompt_ids + context_ids for context_ids in context_ids_batch]
            )
        )
        return [self._process_logw_next(logw_next) for logw_next in logw_nexts]

    def __repr__(self):
        return f"PromptedLLM(prompt={self.prompt!r})"

    def spawn(self, prompt_ids=None, eos_byte_strings=None, temperature=None, **kwargs):
        """
        Spawn a new PromptedLLM.

        Args:
            prompt_ids (optional, list[int]): The prompt to use as a prompt prefix for all input contexts.
                Defaults to the same prompt_ids as `self`.
            eos_byte_strings (optional, list[bytes]): A list of tokens to treat as end-of-sequence tokens.
                Defaults to the same eos_byte_strings as `self`.
            temperature (optional, float): The temperature with which to rescale logprobs.
                Defaults to the same temperature as `self`.

        Returns:
            (PromptedLLM): A new PromptedLLM; any argument left as None is inherited from `self`.

        Note:
            This is a shallow copy. The new PromptedLLM will share the underlying AsyncLM instance.
        """
        eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
        prompt_ids = prompt_ids if prompt_ids is not None else self.prompt_ids.copy()
        temperature = temperature if temperature is not None else self.temperature

        if (eos_byte_strings is None) or (eos_byte_strings == self.token_maps.eos_byte_strings):
            # If the eos tokens don't change, we don't need to recompute the token maps or vocabulary.
            return PromptedLLM(
                self.model,
                prompt_ids=prompt_ids,
                temperature=temperature,
                token_maps=self.token_maps,
            )

        return PromptedLLM(
            self.model,
            prompt_ids=prompt_ids,
            eos_byte_strings=eos_byte_strings,
            temperature=temperature,
        )

    def spawn_new_eos(self, eos_byte_strings=None, **kwargs):
        """
        Create a new PromptedLLM with a different set of end-of-sequence tokens.

        Args:
            eos_byte_strings (list[bytes]): A list of tokens to treat as end-of-sequence tokens.

        Returns:
            (PromptedLLM): A new PromptedLLM with the specified end-of-sequence tokens.
                The new model will have the same prompt_ids as `self`.
        """
        eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
        return self.spawn(eos_byte_strings=eos_byte_strings)

    def to_autobatched(self):
        raise ValueError("PromptedLLMs are autobatched by default.")

__init__(llm, prompt_ids=None, eos_byte_strings=None, temperature=1.0, token_maps=None, **kwargs)

Initializes the PromptedLLM potential.

Parameters:

Name Type Description Default
llm AsyncLM

The language model to use.

required
prompt_ids list[int]

Optional prompt to use as a prompt prefix for all input contexts. Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via set_prompt_from_str or prompt_ids.

None
eos_byte_strings list[bytes]

List of tokens to treat as end-of-sequence tokens. Defaults to the EOS token of the language model's tokenizer.

None
temperature float

The temperature to apply to the language model's logits. Defaults to 1.

1.0
token_maps TokenMappings

A precomputed mapping of tokens to token IDs with the potential's vocabulary. If provided, eos_byte_strings must not be provided. Defaults to None, which constructs a TokenMappings from the language model's byte vocabulary and the EOS tokens.

None
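The temperature argument is applied by _maybe_temper in the class source above: when it differs from 1, the log probabilities are divided by the temperature and renormalized with a log-softmax. The pure-Python sketch below shows the effect on a two-token distribution. The function name is hypothetical, and lists stand in for torch tensors. Temperatures below 1 sharpen the distribution, and temperatures above 1 flatten it.

```python
import math


def temper(logps, temperature):
    """Rescale log probabilities by 1/temperature and renormalize."""
    if temperature == 1:
        return list(logps)  # unchanged, as in _maybe_temper
    scaled = [x / temperature for x in logps]
    m = max(scaled)
    z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - z for x in scaled]


p = [math.log(0.75), math.log(0.25)]
cold = temper(p, 0.5)  # probabilities proportional to 0.75**2 and 0.25**2, i.e. 0.9 and 0.1
hot = temper(p, 2.0)  # probabilities move toward uniform
```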
Source code in genlm/control/potential/built_in/llm.py
def __init__(
    self,
    llm,
    prompt_ids=None,
    eos_byte_strings=None,
    temperature=1.0,
    token_maps=None,
    **kwargs,
):
    """`
    Initializes the PromptedLLM potential.

    Args:
        llm (AsyncLM): The language model to use.
        prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
            Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via `set_prompt_from_str` or `prompt_ids`.
        eos_byte_strings (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
            Defaults to the EOS token of the language model's tokenizer.
        temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
        token_maps (TokenMappings, optional): A precomputed mapping of tokens to token IDs with the potential's vocabulary.
            If provided, `eos_byte_strings` must not be provided. Defaults to None, which constructs a TokenMappings from the language model's byte vocabulary and the EOS tokens.
    """
    eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
    self.model = llm
    self.prompt_ids = prompt_ids or []
    self.temperature = temperature

    if token_maps is not None:
        if eos_byte_strings is not None:
            raise ValueError(
                "eos_byte_strings must not be provided when token_maps is provided."
            )
        self.token_maps = token_maps
    else:
        byte_vocab = self.model.byte_vocab
        default_eos = byte_vocab[self.model.tokenizer.eos_token_id].byte_string
        self.token_maps = TokenMappings.create(
            decode=byte_vocab,
            eos_byte_strings=eos_byte_strings or [default_eos],
        )

    super().__init__(vocabulary=self.token_maps.potential_vocab)

from_name(name, backend=None, eos_byte_strings=None, prompt_ids=None, temperature=1.0, **kwargs) classmethod

Create a PromptedLLM from a Hugging Face model name.

Parameters:

Name Type Description Default
name str

Name of the model to load

required
backend str

AsyncLM backend to use:

  • 'vllm' to instantiate an AsyncVirtualLM; ideal for GPU usage

  • 'hf' for an AsyncTransformer; ideal for CPU usage

  • 'mlx' for an AsyncMlxLM; ideal for Apple silicon usage

  • 'mock' for a MockAsyncLM; ideal for testing.

Defaults to 'vllm' if CUDA is available, otherwise 'hf'.

None
eos_byte_strings list[bytes]

List of tokens to treat as end-of-sequence tokens. Defaults to the EOS token of the language model's tokenizer.

None
prompt_ids list[int]

Optional prompt to use as a prompt prefix for all input contexts. Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via set_prompt_from_str or prompt_ids.

None
temperature float

The temperature to apply to the language model's logits. Defaults to 1.

1.0
**kwargs dict

Additional arguments passed to AsyncLM constructor

{}

Returns:

Type Description
PromptedLLM

An instance of PromptedLLM

Source code in genlm/control/potential/built_in/llm.py
@classmethod
def from_name(
    cls,
    name,
    backend=None,
    eos_byte_strings=None,
    prompt_ids=None,
    temperature=1.0,
    **kwargs,
):
    """Create a `PromptedLLM` from a Hugging Face model name.

    Args:
        name (str): Name of the model to load
        backend (str, optional): `AsyncLM` backend to use:\n
            * 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage\n
            * 'hf' for an `AsyncTransformer`; ideal for CPU usage\n
            * 'mlx' for an `AsyncMlxLM`; ideal for Apple silicon usage\n
            * 'mock' for a `MockAsyncLM`; ideal for testing.\n
            Defaults to 'vllm' if CUDA is available, otherwise 'hf'.
        eos_byte_strings (list[bytes], optional): List of tokens to treat as end-of-sequence tokens.
            Defaults to the EOS token of the language model's tokenizer.
        prompt_ids (list[int], optional): Optional prompt to use as a prompt prefix for all input contexts.
            Must be a list of token IDs. Defaults to None. The prompt ids can be set post-init via `set_prompt_from_str` or `prompt_ids`.
        temperature (float, optional): The temperature to apply to the language model's logits. Defaults to 1.
        **kwargs (dict): Additional arguments passed to AsyncLM constructor

    Returns:
        (PromptedLLM): An instance of PromptedLLM
    """
    eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
    backend = backend or ("vllm" if torch.cuda.is_available() else "hf")
    model = load_model_by_name(name, backend=backend, **kwargs)
    return cls(
        model, prompt_ids=prompt_ids, eos_byte_strings=eos_byte_strings, temperature=temperature
    )
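The backend default in `from_name` reduces to a single expression. The hypothetical helper below isolates that rule so it can be checked without a GPU; the validation step is an illustrative addition, since `from_name` itself does not check the backend name.

```python
# Hypothetical helper mirroring the defaulting rule in from_name.
KNOWN_BACKENDS = {"vllm", "hf", "mlx", "mock"}


def choose_backend(backend=None, cuda_available=False):
    # An explicit choice always wins; otherwise prefer vllm on CUDA and fall back to hf.
    # 'mlx' and 'mock' are never selected automatically and must be requested by name.
    backend = backend or ("vllm" if cuda_available else "hf")
    # Illustrative validation only; from_name does not perform this check.
    if backend not in KNOWN_BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {sorted(KNOWN_BACKENDS)}")
    return backend
```

Because automatic selection never picks `'mlx'`, on Apple silicon the backend has to be passed explicitly.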

prompt property

Get the current prompt as Token objects.

Returns:

Type Description
list[Token] | None

The current prompt as Token objects, or None if no prompt_ids are set.

set_prompt_from_str(prompt_str)

Set the fixed prompt from a string.

Modifies prompt_ids to be the token IDs of the input prompt according to the language model's tokenizer.

Parameters:

Name Type Description Default
prompt_str str

The prompt to set.

required
Source code in genlm/control/potential/built_in/llm.py
def set_prompt_from_str(self, prompt_str):
    """Set the fixed prompt from a string.

    Modifies `prompt_ids` to be the token IDs of the input prompt according to the language model's tokenizer.

    Args:
        prompt_str (str): The prompt to set.
    """
    # TODO: Handle race condition where prompt_ids reset concurrently.
    if not isinstance(prompt_str, str):
        raise ValueError(
            f"Prompt must be a string, got {type(prompt_str)}. "
            "To set the prompt from a list of token IDs, use prompt_ids."
        )

    if prompt_str.endswith(" "):
        warnings.warn(
            "Prompt ends with whitespace, which may affect tokenization. "
            "Consider removing trailing whitespace.",
            stacklevel=2,
        )

    self.prompt_ids = self.model.tokenizer.encode(prompt_str)
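Why a trailing space triggers the warning: many BPE tokenizers attach a leading space to the following word, so a prompt ending in a space splits differently from the same text written without it. The sketch below uses a toy regex pre-tokenizer (`toy_tokenize`, hypothetical) rather than a real model tokenizer; real vocabularies differ in detail.

```python
import re

# Toy pre-tokenizer: a word absorbs at most one leading space, loosely like GPT-2-style BPE.
TOKEN_RE = re.compile(r" ?\S+|\s+")


def toy_tokenize(text):
    return TOKEN_RE.findall(text)


natural = toy_tokenize("Hello world")
# Prompt without a trailing space: the continuation keeps its leading space on the word.
clean = toy_tokenize("Hello") + toy_tokenize(" world")
# Prompt with a trailing space: the space becomes its own token and the word loses it.
trailing = toy_tokenize("Hello ") + toy_tokenize("world")
```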

encode_tokens(tokens)

Encode a list of Token objects to token IDs.

Parameters:

Name Type Description Default
tokens list[Token]

List of Token objects

required

Returns:

Type Description
list[int]

A list of token IDs corresponding to the input tokens.

Raises:

Type Description
ValueError

If any token is not in the vocabulary.

Note

Passing bytes is deprecated. Use Token objects from llm.tokenize().

Source code in genlm/control/potential/built_in/llm.py
def encode_tokens(self, tokens):
    """Encode a list of Token objects to token IDs.

    Args:
        tokens (list[Token]): List of Token objects

    Returns:
        (list[int]): A list of token IDs corresponding to the input tokens.

    Raises:
        ValueError: If any token is not in the vocabulary.

    Note:
        Passing bytes is deprecated. Use Token objects from llm.tokenize().
    """
    if not tokens:
        return []

    result = []
    warned = False
    for item in tokens:
        if isinstance(item, Token):
            result.append(item.token_id)
        else:
            if not warned:
                warnings.warn(
                    "Passing bytes to encode_tokens is deprecated. "
                    "Use Token objects for precise control.",
                    DeprecationWarning,
                    stacklevel=3,
                )
                warned = True
            token_id = self._find_token_id_for_bytes(item)
            if token_id is None:
                raise ValueError(f"Token {item!r} not in vocabulary")
            result.append(token_id)
    return result
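The loop above warns at most once per call and resolves each bytes item through a lookup. A self-contained sketch of the same pattern, with a hypothetical `Token` stand-in and a three-entry dictionary in place of `_find_token_id_for_bytes`:

```python
import warnings
from dataclasses import dataclass


@dataclass(frozen=True)
class Token:  # hypothetical stand-in for genlm's Token
    token_id: int
    byte_string: bytes


VOCAB = [b"Hello", b" world", b"!"]
BYTES_TO_ID = {b: i for i, b in enumerate(VOCAB)}  # replaces _find_token_id_for_bytes


def encode_tokens(tokens):
    result, warned = [], False
    for item in tokens:
        if isinstance(item, Token):
            result.append(item.token_id)
            continue
        if not warned:  # emit the deprecation warning at most once per call
            warnings.warn("Passing bytes is deprecated.", DeprecationWarning, stacklevel=2)
            warned = True
        token_id = BYTES_TO_ID.get(item)
        if token_id is None:
            raise ValueError(f"Token {item!r} not in vocabulary")
        result.append(token_id)
    return result
```

Since `warned` is a local variable, the flag only deduplicates within a single call; separate calls may each warn.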

decode_tokens(ids)

Decode a list of token IDs to Token objects.

Parameters:

Name Type Description Default
ids list[int]

A list of token IDs in the language model's vocabulary.

required

Returns:

Type Description
list[Token]

Token objects corresponding to the input token IDs.

Source code in genlm/control/potential/built_in/llm.py
def decode_tokens(self, ids):
    """
    Decode a list of token IDs to Token objects.

    Args:
        ids (list[int]): A list of token IDs in the language model's vocabulary.

    Returns:
        (list[Token]): Token objects corresponding to the input token IDs.
    """
    return [self.token_maps.decode[x] for x in ids]

tokenize(context_str)

Tokenize a string to a list of Token objects.

Uses the language model's tokenizer to map context_str to token IDs, then returns the corresponding Token objects.

Parameters:

Name Type Description Default
context_str str

A string to encode

required

Returns:

Type Description
list[Token]

Token objects corresponding to the input string.

Source code in genlm/control/potential/built_in/llm.py
def tokenize(self, context_str):
    """Tokenize a string to a list of Token objects.

    Uses the language model's tokenizer to map `context_str` to token IDs,
    then returns the corresponding Token objects.

    Args:
        context_str (str): A string to encode

    Returns:
        (list[Token]): Token objects corresponding to the input string.
    """
    return self.decode_tokens(self.model.tokenizer.encode(context_str))
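`tokenize` composes the tokenizer's `encode` with `decode_tokens`. The sketch below mimics that round trip with a hypothetical `Token` stand-in, a three-entry decode table in place of `token_maps.decode`, and a greedy longest-match encoder (`toy_encode`) in place of a real tokenizer.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Token:  # hypothetical stand-in for genlm's Token
    token_id: int
    byte_string: bytes


# Hypothetical decode table indexed by token ID (plays the role of token_maps.decode).
DECODE = [Token(i, b) for i, b in enumerate([b"Hello", b" world", b"!"])]


def toy_encode(text):
    """Greedy longest-match encoder; a stand-in for model.tokenizer.encode."""
    data, ids = text.encode("utf-8"), []
    while data:
        match = max(
            (t for t in DECODE if data.startswith(t.byte_string)),
            key=lambda t: len(t.byte_string),
            default=None,
        )
        if match is None:
            raise ValueError(f"cannot encode {data!r}")
        ids.append(match.token_id)
        data = data[len(match.byte_string):]
    return ids


def decode_tokens(ids):
    return [DECODE[i] for i in ids]


def tokenize(text):
    return decode_tokens(toy_encode(text))


tokens = tokenize("Hello world!")
```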

log_probability(context) async

Compute the log probability of context given the prompt.

Parameters:

Name Type Description Default
context list[bytes] | list[Token]

A sequence of byte tokens or Token objects.

required

Returns:

Type Description
float

The log probability of context.

Source code in genlm/control/potential/built_in/llm.py
async def log_probability(self, context):
    """
    Compute the log probability of `context` given the prompt.

    Args:
        context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

    Returns:
        (float): The log probability of `context`.
    """
    if not context:
        return 0

    context_ids = self.encode_tokens(context)
    return await self._log_probability(context_ids)
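`_log_probability` is not shown on this page. Assuming it applies the chain rule, scoring each context token conditioned on the prompt and the preceding context tokens (the prompt tokens themselves are not scored), a minimal sketch with a hypothetical bigram-style toy model looks like this. A real implementation would typically score all positions in one forward pass rather than one call per token.

```python
import math

# Hypothetical toy LM: the next-token distribution depends only on the previous token ID.
NEXT = {
    None: {0: 0.6, 1: 0.4},  # distribution after an empty sequence
    0: {0: 0.1, 1: 0.9},
    1: {0: 0.5, 1: 0.5},
}


def next_token_logprobs(ids):
    prev = ids[-1] if ids else None
    return {tok: math.log(p) for tok, p in NEXT[prev].items()}


def log_probability(prompt_ids, context_ids):
    """Chain rule: sum of log p(x_t | prompt, x_<t) over the context tokens only."""
    total = 0.0
    for t, tok in enumerate(context_ids):
        total += next_token_logprobs(prompt_ids + context_ids[:t])[tok]
    return total
```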

prefix(context) async

Compute the log probability of context given the prompt.

Parameters:

Name Type Description Default
context list[bytes] | list[Token]

A sequence of byte tokens or Token objects.

required

Returns:

Type Description
float

The log probability of context.

Source code in genlm/control/potential/built_in/llm.py
async def prefix(self, context):
    """
    Compute the log probability of `context` given the prompt.

    Args:
        context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

    Returns:
        (float): The log probability of `context`.
    """
    return await self.log_probability(context)

complete(context) async

Compute the log probability of context followed by an end-of-sequence token, given the prompt.

If the model has multiple eos tokens, their probabilities will be summed.

Parameters:

Name Type Description Default
context list[bytes] | list[Token]

A sequence of byte tokens or Token objects.

required

Returns:

Type Description
float

The log probability of context followed by an end-of-sequence token.

Source code in genlm/control/potential/built_in/llm.py
async def complete(self, context):
    """
    Compute the log probability of `context` followed by an end-of-sequence token, given the prompt.

    If the model has multiple eos tokens, their probabilities will be summed.

    Args:
        context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

    Returns:
        (float): The log probability of `context` followed by an end-of-sequence token.
    """
    context_ids = self.encode_tokens(context)
    logp_context = await self._log_probability(context_ids)
    logp_next = self._maybe_temper(
        await self.model.next_token_logprobs(self.prompt_ids + context_ids)
    )
    logp_eos = torch.logsumexp(logp_next[self.token_maps.eos_idxs], dim=0).item()
    return logp_context + logp_eos
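`complete` adds the log probability of ending the sequence to the context score, marginalizing over every EOS token with a log-sum-exp. A stdlib-only sketch of that last step, using plain lists in place of tensors (`complete_logweight` is a hypothetical name):

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == float("-inf"):
        return m  # every term has zero probability
    return m + math.log(sum(math.exp(x - m) for x in xs))


def complete_logweight(logp_context, logp_next, eos_idxs):
    """log p(context | prompt) + log of the summed probability of all EOS tokens."""
    return logp_context + logsumexp([logp_next[i] for i in eos_idxs])


logp_next = [math.log(p) for p in (0.5, 0.3, 0.1, 0.1)]  # toy next-token distribution
score = complete_logweight(math.log(0.2), logp_next, eos_idxs=[2, 3])
```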

logw_next(context) async

Get log probabilities for next tokens given the prompt and context.

Parameters:

Name Type Description Default
context list[bytes] | list[Token]

A sequence of byte tokens or Token objects.

required

Returns:

Type Description
LazyWeights

Log probabilities for next tokens and EOS. Keys are Token objects.

Source code in genlm/control/potential/built_in/llm.py
async def logw_next(self, context):
    """Get log probabilities for next tokens given the prompt and `context`.

    Args:
        context (list[bytes] | list[Token]): A sequence of byte tokens or Token objects.

    Returns:
        (LazyWeights): Log probabilities for next tokens and EOS. Keys are Token objects.
    """
    context_ids = self.encode_tokens(context)
    logw_next = self._maybe_temper(
        await self.model.next_token_logprobs(self.prompt_ids + context_ids)
    )
    return self._process_logw_next(logw_next)

batch_logw_next(contexts) async

Get log probabilities for next tokens given the prompt and context, for a batch of contexts.

Parameters:

Name Type Description Default
contexts list[list[bytes]] | list[list[Token]]

A list of sequences of byte tokens or Token objects.

required

Returns:

Type Description
list[LazyWeights]

Log probabilities for next tokens and EOS for each context. Keys are Token objects.

Source code in genlm/control/potential/built_in/llm.py
async def batch_logw_next(self, contexts):
    """Get log probabilities for next tokens given the prompt and `context`, for a batch of contexts.

    Args:
        contexts (list[list[bytes]] | list[list[Token]]): A list of sequences of byte tokens or Token objects.

    Returns:
        (list[LazyWeights]): Log probabilities for next tokens and EOS for each context. Keys are Token objects.
    """
    context_ids_batch = [self.encode_tokens(context) for context in contexts]
    logw_nexts = self._maybe_temper(
        await self.model.batch_next_token_logprobs(
            [self.prompt_ids + context_ids for context_ids in context_ids_batch]
        )
    )
    return [self._process_logw_next(logw_next) for logw_next in logw_nexts]

spawn(prompt_ids=None, eos_byte_strings=None, temperature=None, **kwargs)

Spawn a new PromptedLLM.

Parameters:

Name Type Description Default
prompt_ids (optional, list[int])

The prompt to use as a prompt prefix for all input contexts. Defaults to the same prompt_ids as self.

None
eos_byte_strings (optional, list[bytes])

A list of tokens to treat as end-of-sequence tokens. Defaults to the same eos_byte_strings as self.

None
temperature (optional, float)

The temperature with which to rescale logprobs. Defaults to the same temperature as self.

None

Returns:

Type Description
PromptedLLM

A new PromptedLLM. Any argument that is not provided is inherited from self.

Note

This is a shallow copy. The new PromptedLLM will share the underlying AsyncLM instance.

Source code in genlm/control/potential/built_in/llm.py
def spawn(self, prompt_ids=None, eos_byte_strings=None, temperature=None, **kwargs):
    """
    Spawn a new PromptedLLM.

    Args:
        prompt_ids (optional, list[int]): The prompt to use as a prompt prefix for all input contexts.
            Defaults to the same prompt_ids as `self`.
        eos_byte_strings (optional, list[bytes]): A list of tokens to treat as end-of-sequence tokens.
            Defaults to the same eos_byte_strings as `self`.
        temperature (optional, float): The temperature with which to rescale logprobs.
            Defaults to the same temperature as `self`.

    Returns:
        (PromptedLLM): A new PromptedLLM. Any argument that is not provided is inherited from `self`.

    Note:
        This is a shallow copy. The new PromptedLLM will share the underlying AsyncLM instance.
    """
    eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
    prompt_ids = prompt_ids if prompt_ids is not None else self.prompt_ids.copy()
    temperature = temperature if temperature is not None else self.temperature

    if (eos_byte_strings is None) or (eos_byte_strings == self.token_maps.eos_byte_strings):
        # If the eos tokens don't change, we don't need to recompute the token maps or vocabulary.
        return PromptedLLM(
            self.model,
            prompt_ids=prompt_ids,
            temperature=temperature,
            token_maps=self.token_maps,
        )

    return PromptedLLM(
        self.model,
        prompt_ids=prompt_ids,
        eos_byte_strings=eos_byte_strings,
        temperature=temperature,
    )
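The copy semantics of `spawn` can be checked on a hypothetical stand-in class that keeps only the fields involved: the model is shared, an inherited `prompt_ids` list is copied, and unset arguments fall back to the parent's values.

```python
class ToyPromptedLLM:
    """Hypothetical stand-in that mirrors spawn()'s sharing rules."""

    def __init__(self, model, prompt_ids=None, temperature=1.0):
        self.model = model
        self.prompt_ids = prompt_ids or []
        self.temperature = temperature

    def spawn(self, prompt_ids=None, temperature=None):
        # Inherited prompt_ids are copied; the model object is shared (shallow copy).
        prompt_ids = prompt_ids if prompt_ids is not None else self.prompt_ids.copy()
        temperature = temperature if temperature is not None else self.temperature
        return ToyPromptedLLM(self.model, prompt_ids=prompt_ids, temperature=temperature)


parent = ToyPromptedLLM(model=object(), prompt_ids=[1, 2, 3], temperature=0.7)
child = parent.spawn()
child.prompt_ids.append(4)  # does not leak into parent: the list was copied
```

One consequence visible in the source: a non-empty `prompt_ids` list passed explicitly to `spawn` is stored as-is rather than copied, so later mutations of that list by the caller are visible to the new instance.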

spawn_new_eos(eos_byte_strings=None, **kwargs)

Create a new PromptedLLM with a different set of end-of-sequence tokens.

Parameters:

Name Type Description Default
eos_byte_strings list[bytes]

A list of tokens to treat as end-of-sequence tokens.

None

Returns:

Type Description
PromptedLLM

A new PromptedLLM with the specified end-of-sequence tokens. The new model will have the same prompt_ids as self.

Source code in genlm/control/potential/built_in/llm.py
def spawn_new_eos(self, eos_byte_strings=None, **kwargs):
    """
    Create a new PromptedLLM with a different set of end-of-sequence tokens.

    Args:
        eos_byte_strings (list[bytes]): A list of tokens to treat as end-of-sequence tokens.

    Returns:
        (PromptedLLM): A new PromptedLLM with the specified end-of-sequence tokens.
            The new model will have the same prompt_ids as `self`.
    """
    eos_byte_strings = _compat_eos_tokens(eos_byte_strings, kwargs)
    return self.spawn(eos_byte_strings=eos_byte_strings)

ByteLLM

Bases: Potential

A potential representing a language model operating at the byte level using beam search.

ByteLLM wraps a language model and uses beam search to compute log probabilities over byte sequences. This enables constrained generation at the byte level while maintaining coherent token-level probabilities through adaptive token healing.

Parameters:

Name Type Description Default
llm Any

The language model to use (from genlm.backend).

required
beam_params BeamParams

Configuration for beam search, including beam width K, eos_byte_strings (list of EOS byte sequences), and healing parameters (heal, heal_max_backoff, heal_max_splits).

required
cache_size int

Maximum number of beam states to cache. Defaults to 1024.

1024
Example
from genlm.bytes import BeamParams
from genlm.control import ByteLLM

beam_params = BeamParams(K=5, eos_byte_strings=[b"<|endoftext|>"], heal=True)
async with ByteLLM.from_name("gpt2", beam_params) as byte_llm:
    byte_llm.set_prompt_from_str("Hello")
    logp = await byte_llm.prefix([b" ", b"w", b"o", b"r", b"l", b"d"])
Source code in genlm/control/potential/built_in/bytellm.py
class ByteLLM(Potential):
    """A potential representing a language model operating at the byte level using beam search.

    `ByteLLM` wraps a language model and uses beam search to compute log probabilities
    over byte sequences. This enables constrained generation at the byte level while
    maintaining coherent token-level probabilities through adaptive token healing.

    Args:
        llm: The language model to use (from `genlm.backend`).
        beam_params (BeamParams): Configuration for beam search, including beam width `K`,
            `eos_byte_strings` (list of EOS byte sequences), and healing parameters
            (`heal`, `heal_max_backoff`, `heal_max_splits`).
        cache_size (int): Maximum number of beam states to cache. Defaults to 1024.

    Example:
        ```python
        from genlm.bytes import BeamParams
        from genlm.control import ByteLLM

        beam_params = BeamParams(K=5, eos_byte_strings=[b"<|endoftext|>"], heal=True)
        async with ByteLLM.from_name("gpt2", beam_params) as byte_llm:
            byte_llm.set_prompt_from_str("Hello")
            logp = await byte_llm.prefix([b" ", b"w", b"o", b"r", b"l", b"d"])
        ```
    """

    def __init__(self, llm: Any, beam_params: BeamParams, cache_size: int = 1024):
        self.llm = llm
        self.beam_params = beam_params
        self.cache_size = cache_size
        vocab = [i.to_bytes(1, "big") for i in range(256)]
        super().__init__(vocabulary=vocab)
        # LRU cache of ByteBeamState keyed by full context bytes (prompt + context)
        self._beam_cache: OrderedDict[bytes, ByteBeamState] = OrderedDict()
        self._initial_beam = None
        self.prompt_bytes = b""
        # Fast path: cache last accessed beam for sequential access
        self._last_context = None
        self._last_beam = None

    @classmethod
    def from_name(
        cls,
        name,
        beam_params: BeamParams,
        backend=None,
        cache_size: int = 1024,
        **kwargs,
    ):
        backend = backend or ("vllm" if torch.cuda.is_available() else "hf")
        llm = load_model_by_name(name, backend=backend, **kwargs)
        return cls(llm, beam_params, cache_size=cache_size)

    def set_prompt_from_str(self, prompt_str: str):
        new_prompt_bytes = prompt_str.encode("utf-8")
        if new_prompt_bytes != self.prompt_bytes:
            self.prompt_bytes = new_prompt_bytes
            self._beam_cache.clear()
            self._initial_beam = None
            self._last_context = None
            self._last_beam = None

    async def _get_or_create_beam_for_context(self, context):
        context_bytes = b"".join(context)
        full_context_bytes = self.prompt_bytes + context_bytes

        # Fast path: exact cache hit
        if full_context_bytes in self._beam_cache:
            self._beam_cache.move_to_end(full_context_bytes)
            beam = self._beam_cache[full_context_bytes]
            self._last_context = full_context_bytes
            self._last_beam = beam
            return beam

        # Fast path: sequential access from last beam
        if (
            self._last_context is not None
            and full_context_bytes.startswith(self._last_context)
            and len(full_context_bytes) > len(self._last_context)
        ):
            best_prefix_bytes = self._last_context
            best_beam = self._last_beam
        else:
            # Search cache for longest prefix match
            best_prefix_bytes = b""
            best_beam = None
            for cached_prefix_bytes, cached_beam in self._beam_cache.items():
                if full_context_bytes.startswith(cached_prefix_bytes) and len(
                    cached_prefix_bytes
                ) > len(best_prefix_bytes):
                    best_prefix_bytes = cached_prefix_bytes
                    best_beam = cached_beam

            if best_beam is None:
                if self._initial_beam is None:
                    self._initial_beam = await ByteBeamState.initial(
                        self.llm, self.beam_params
                    )
                    if self.prompt_bytes:
                        self._initial_beam = await self._initial_beam.prefill(
                            self.prompt_bytes
                        )
                        self._cache_put(self.prompt_bytes, self._initial_beam)
                best_beam = self._initial_beam
                best_prefix_bytes = (
                    self.prompt_bytes
                    if full_context_bytes.startswith(self.prompt_bytes)
                    else b""
                )

        # Advance beam byte-by-byte
        remaining_bytes = full_context_bytes[len(best_prefix_bytes) :]
        current_beam = best_beam
        current_prefix_bytes = best_prefix_bytes

        for i, byte_val in enumerate(remaining_bytes):
            current_beam = current_beam.prune()
            current_beam = await (current_beam << byte_val)
            current_prefix_bytes += remaining_bytes[i : i + 1]

            if len(current_beam) == 0:
                raise ValueError(
                    f"Beam became empty at byte {byte_val} ({chr(byte_val) if 32 <= byte_val < 127 else f'0x{byte_val:02x}'}). "
                    f"Context so far: {current_prefix_bytes!r}. "
                    f"Consider enabling healing or increasing beam width K."
                )

            self._cache_put(current_prefix_bytes, current_beam)

        # Update last beam for fast sequential access
        self._last_context = full_context_bytes
        self._last_beam = current_beam

        return current_beam

    def _cache_put(self, key: bytes, beam: ByteBeamState):
        self._beam_cache[key] = beam
        self._beam_cache.move_to_end(key)
        while len(self._beam_cache) > self.cache_size:
            self._beam_cache.popitem(last=False)

    async def prefix(self, context):
        # Treat empty context as neutral (log 1 = 0), matching PromptedLLM semantics.
        # The prompt, if set, is incorporated into next-token distributions via the cached beam,
        # but does not contribute to the prefix weight of the empty context.
        if not context:
            return 0.0
        beam = await self._get_or_create_beam_for_context(context)
        base = self._initial_beam.logZ if self._initial_beam is not None else 0.0
        return beam.logZ - base

    async def complete(self, context):
        beam = await self._get_or_create_beam_for_context(context)
        logp_next = await beam.logp_next()
        # Assume logp_next.ps contains log-probs for 256 byte values plus EOS at the end.
        eos_logp = logp_next.ps[-1]
        base = self._initial_beam.logZ if self._initial_beam is not None else 0.0
        return (beam.logZ - base) + eos_logp

    async def logw_next(self, context):
        """Efficient next-token weights using the cached beam state.

        Uses the beam's next-token distribution directly instead of the
        default (slower) fallback that recomputes scores for each token.
        """
        beam = await self._get_or_create_beam_for_context(context)
        logp_next = await beam.logp_next()

        # Build weights over vocab_eos (256 bytes + EOS at the end)
        ps = np.asarray(logp_next.ps)
        logws = self.alloc_logws()
        v = len(self.vocab)
        logws[:v] = ps[:v]
        logws[-1] = ps[-1]
        return self.make_lazy_weights(logws)

    async def cleanup(self):
        """Cleans up resources used by the beam states.

        This method is called automatically when using ByteLLM as an async context manager.
        If not using a context manager, you should call this method manually when done.
        """
        if self._initial_beam:
            await self._initial_beam.cleanup()
        for beam in self._beam_cache.values():
            await beam.cleanup()
        self._beam_cache.clear()
        self._last_context = None
        self._last_beam = None

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - ensures cleanup is called."""
        await self.cleanup()
        return False
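The lookup in `_get_or_create_beam_for_context` combines an LRU `OrderedDict` with a longest-prefix scan, then advances the best beam over the remaining bytes. The sketch below keeps only the cache logic, with strings standing in for `ByteBeamState` objects; `PrefixLRU` is a hypothetical name.

```python
from collections import OrderedDict


class PrefixLRU:
    """LRU cache keyed by byte prefixes, with longest-prefix lookup (illustrative only)."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._data = OrderedDict()

    def put(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)  # mark as most recently used
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry

    def longest_prefix(self, query):
        """Return (prefix, value) for the longest cached key that prefixes query, else (b"", None)."""
        best_key, best_val = b"", None
        for key, val in self._data.items():
            if query.startswith(key) and len(key) > len(best_key):
                best_key, best_val = key, val
        return best_key, best_val


cache = PrefixLRU(max_size=2)
cache.put(b"He", "beam@He")
cache.put(b"Hell", "beam@Hell")
cache.put(b"Help", "beam@Help")  # exceeds max_size, so b"He" is evicted
prefix, beam = cache.longest_prefix(b"Hello")  # only b"o" remains to be advanced
```

The scan is linear in the number of cached entries, which is one reason the sequential fast path through `_last_context` is checked first.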

logw_next(context) async

Efficient next-token weights using the cached beam state.

Uses the beam's next-token distribution directly instead of the default (slower) fallback that recomputes scores for each token.

Source code in genlm/control/potential/built_in/bytellm.py
async def logw_next(self, context):
    """Efficient next-token weights using the cached beam state.

    Uses the beam's next-token distribution directly instead of the
    default (slower) fallback that recomputes scores for each token.
    """
    beam = await self._get_or_create_beam_for_context(context)
    logp_next = await beam.logp_next()

    # Build weights over vocab_eos (256 bytes + EOS at the end)
    ps = np.asarray(logp_next.ps)
    logws = self.alloc_logws()
    v = len(self.vocab)
    logws[:v] = ps[:v]
    logws[-1] = ps[-1]
    return self.make_lazy_weights(logws)

cleanup() async

Cleans up resources used by the beam states.

This method is called automatically when using ByteLLM as an async context manager. If not using a context manager, you should call this method manually when done.

Source code in genlm/control/potential/built_in/bytellm.py
async def cleanup(self):
    """Cleans up resources used by the beam states.

    This method is called automatically when using ByteLLM as an async context manager.
    If not using a context manager, you should call this method manually when done.
    """
    if self._initial_beam:
        await self._initial_beam.cleanup()
    for beam in self._beam_cache.values():
        await beam.cleanup()
    self._beam_cache.clear()
    self._last_context = None
    self._last_beam = None

__aenter__() async

Async context manager entry.

Source code in genlm/control/potential/built_in/bytellm.py
async def __aenter__(self):
    """Async context manager entry."""
    return self

__aexit__(exc_type, exc_val, exc_tb) async

Async context manager exit - ensures cleanup is called.

Source code in genlm/control/potential/built_in/bytellm.py
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context manager exit - ensures cleanup is called."""
    await self.cleanup()
    return False
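The async context manager guarantees that `cleanup` runs even when the body raises, and returning `False` from `__aexit__` lets the exception propagate. A hypothetical stand-in class demonstrates both behaviours without loading a model:

```python
import asyncio


class ToyResource:
    """Hypothetical stand-in showing the __aenter__/__aexit__ cleanup pattern."""

    def __init__(self):
        self.cleaned_up = False

    async def cleanup(self):
        self.cleaned_up = True

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.cleanup()
        return False  # do not suppress exceptions raised inside the block


async def main():
    res = ToyResource()
    try:
        async with res:
            raise RuntimeError("boom")
    except RuntimeError:
        pass  # the exception propagated, but cleanup still ran
    return res


resource = asyncio.run(main())
```

Without `async with`, the equivalent is awaiting `cleanup()` in a `finally` block.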

WCFG

Bases: Potential

A weighted context-free grammar potential.

This class wraps a genlm_grammar.CFG and provides methods for computing the log-weight of a sequence, the prefix log-weight of a sequence, and the log-weights of the next token given a sequence.

Source code in genlm/control/potential/built_in/wcfg.py
class WCFG(Potential):
    """
    A weighted context-free grammar potential.

    This class wraps a `genlm_grammar.CFG` and provides methods for computing the log-weight of a sequence,
    the prefix log-weight of a sequence, and the log-weights of the next token given a sequence.
    """

    def __init__(self, cfg):
        """
        Initialize the WCFG potential.

        Args:
            cfg (genlm_grammar.CFG): The context-free grammar configuration to use.
                The CFG must be in the Float semiring.
        """
        # TODO: convert to LogSemiring to handle underflow
        if cfg.R is not Float:
            raise ValueError("cfg semiring must be Float")
        self.cfg = cfg  # cfg before prefix transform
        self.cfg_eos = _add_eos(cfg, EOS)  # augmented with eos
        self.model = Earley(self.cfg_eos.prefix_grammar)
        super().__init__(vocabulary=list(cfg.V))

    @classmethod
    def from_string(cls, grammar, to_bytes=True, **kwargs):
        """Create a WCFG from a string.

        Args:
            grammar (str): The string grammar specification to create the WCFG from.
            to_bytes (bool, optional): Whether to convert the WCFG terminals to individual bytes.
                Defaults to True.
            **kwargs (dict): Additional arguments passed to the WCFG constructor.

        Returns:
            (WCFG): The created WCFG.
        """
        cfg = CFG.from_string(grammar, Float)
        if to_bytes:
            cfg = cfg.to_bytes()
        return cls(cfg, **kwargs)

    async def complete(self, context):
        """
        Compute the log weight of `context` under the WCFG.

        For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WCFG\n
        - `complete("cat")` returns $\\log(w_{cat})$\n
        - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WCFG

        Args:
            context (list): A sequence of tokens in the WCFG's alphabet.

        Returns:
            (float): The log weight of `context` under the WCFG.
        """
        w = self.model([*context, EOS])
        return np.log(w) if w > 0 else float("-inf")

    async def prefix(self, context):
        """
        Compute the log prefix weight of `context` under the WCFG.

        This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

        For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
        - `prefix("cat")` returns $\\log(w_{cat})$\n
        - `prefix("d")` returns $-\\infty$ since the WCFG does not accept any sequences with prefix "d"

        Args:
            context (list): A sequence of tokens in the WCFG's alphabet.

        Returns:
            (float): The log prefix weight of `context` under the WCFG.
        """
        w = self.model(context)
        return np.log(w) if w > 0 else float("-inf")

    async def logw_next(self, context):
        """
        Compute the next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the WCFG's alphabet.

        Returns:
            (LazyWeights): The log weights for the next tokens and EOS given `context`.
        """
        ws = self.model.next_token_weights(self.model.chart(context))
        ws = ws.trim().normalize()

        ws_array = np.array([ws[x] for x in self.vocab_eos])
        mask = ws_array > 0
        log_ws = np.full_like(ws_array, float("-inf"), dtype=np.float64)
        log_ws[mask] = np.log(ws_array[mask])

        return self.make_lazy_weights(log_ws)

    def clear_cache(self):
        """Clear the internal cache of the parser."""
        self.model.clear_cache()

    def __repr__(self):
        return f"WCFG(cfg={self.cfg!r})"

    def _repr_html_(self):
        return self.cfg._repr_html_()

    def spawn(self):
        """Spawn a new WCFG."""
        return WCFG(self.cfg)
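The semantics of `complete`, `prefix`, and `logw_next` can be reproduced on a finite toy language by brute-force enumeration. The sketch assumes the `cat`/`car` example from the docstrings, uses a hypothetical string sentinel for `EOS`, and appends it to every string, which is assumed to mirror how `_add_eos` augments the grammar. The real class runs an Earley parser on the prefix grammar instead of enumerating strings.

```python
import math

EOS = "<eos>"  # hypothetical sentinel standing in for genlm's EOS token

# Toy finite weighted language from the docstrings: "cat" and "car".
WEIGHTS = {("c", "a", "t"): 0.3, ("c", "a", "r"): 0.5}
# Append EOS to every string (assumed to mirror what _add_eos does to the grammar).
AUGMENTED = {s + (EOS,): w for s, w in WEIGHTS.items()}


def prefix_weight(context):
    """Total weight of all augmented strings that begin with `context`."""
    n = len(context)
    return sum(w for s, w in AUGMENTED.items() if s[:n] == tuple(context))


def log_or_neg_inf(w):
    return math.log(w) if w > 0 else float("-inf")


def prefix(context):
    return log_or_neg_inf(prefix_weight(context))


def complete(context):
    # A complete string is a prefix that is immediately followed by EOS.
    return log_or_neg_inf(prefix_weight([*context, EOS]))


def next_token_weights(context, vocab_eos):
    """Normalized next-token weights: prefix_weight(context + [x]) / prefix_weight(context)."""
    z = prefix_weight(context)
    if z == 0:
        raise ValueError(f"{context!r} is not a prefix of any string in the language")
    return {x: prefix_weight([*context, x]) / z for x in vocab_eos}
```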

__init__(cfg)

Initialize the WCFG potential.

Parameters:

Name Type Description Default
cfg CFG

The context-free grammar configuration to use. The CFG must be in the Float semiring.

required
Source code in genlm/control/potential/built_in/wcfg.py
def __init__(self, cfg):
    """
    Initialize the WCFG potential.

    Args:
        cfg (genlm_grammar.CFG): The context-free grammar configuration to use.
            The CFG must be in the Float semiring.
    """
    # TODO: convert to LogSemiring to handle underflow
    if cfg.R is not Float:
        raise ValueError("cfg semiring must be Float")
    self.cfg = cfg  # cfg before prefix transform
    self.cfg_eos = _add_eos(cfg, EOS)  # augmented with eos
    self.model = Earley(self.cfg_eos.prefix_grammar)
    super().__init__(vocabulary=list(cfg.V))

from_string(grammar, to_bytes=True, **kwargs) classmethod

Create a WCFG from a string.

Parameters:

  • grammar (str, required): The string grammar specification to create the WCFG from.

  • to_bytes (bool, default: True): Whether to convert the WCFG terminals to individual bytes.

  • **kwargs (dict, default: {}): Additional arguments passed to the WCFG constructor.

Returns:

  • WCFG: The created WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
@classmethod
def from_string(cls, grammar, to_bytes=True, **kwargs):
    """Create a WCFG from a string.

    Args:
        grammar (str): The string grammar specification to create the WCFG from.
        to_bytes (bool, optional): Whether to convert the WCFG terminals to individual bytes.
            Defaults to True.
        **kwargs (dict): Additional arguments passed to the WCFG constructor.

    Returns:
        (WCFG): The created WCFG.
    """
    cfg = CFG.from_string(grammar, Float)
    if to_bytes:
        cfg = cfg.to_bytes()
    return cls(cfg, **kwargs)

complete(context) async

Compute the log weight of context under the WCFG.

For example, if the WCFG accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • complete("c") returns \(-\infty\) since this sequence is not accepted by the WCFG

  • complete("cat") returns \(\log(w_{cat})\)

  • complete("d") returns \(-\infty\) since this sequence is not accepted by the WCFG
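The examples above can be reproduced on a toy weighted language. The sketch below is not the library's implementation: it represents the language as an explicit Python dict from strings to weights (the weights 0.6 and 0.4 are hypothetical), uses strings as stand-ins for token lists, and mimics how the source below computes the complete weight, namely by asking for the prefix weight of the context followed by an end-of-sequence marker (here a hypothetical `EOS` sentinel string).

```python
import math

# Toy weighted language (hypothetical weights): each string maps to its weight.
LANGUAGE = {"cat": 0.6, "car": 0.4}
EOS = "<eos>"  # hypothetical end-of-sequence sentinel


def prefix_weight(tokens):
    """Total weight of all EOS-terminated strings that start with `tokens`."""
    tokens = list(tokens)
    n = len(tokens)
    total = 0.0
    for s, w in LANGUAGE.items():
        terminated = list(s) + [EOS]
        if terminated[:n] == tokens:
            total += w
    return total


def complete(context):
    """Log weight of `context` as a complete string: prefix weight of context + [EOS]."""
    w = prefix_weight([*context, EOS])
    return math.log(w) if w > 0 else float("-inf")


print(complete("c"))    # -inf: "c" is not in the language
print(complete("cat"))  # log(0.6)
print(complete("d"))    # -inf
```

Appending EOS is what turns a prefix query into a completeness query: only strings that end exactly where the context ends can match the terminated sequence.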

Parameters:

  • context (list, required): A sequence of tokens in the WCFG's alphabet.

Returns:

  • float: The log weight of context under the WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
async def complete(self, context):
    """
    Compute the log weight of `context` under the WCFG.

    For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WCFG\n
    - `complete("cat")` returns $\\log(w_{cat})$\n
    - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WCFG

    Args:
        context (list): A sequence of tokens in the WCFG's alphabet.

    Returns:
        (float): The log weight of `context` under the WCFG.
    """
    w = self.model([*context, EOS])
    return np.log(w) if w > 0 else float("-inf")

prefix(context) async

Compute the log prefix weight of context under the WCFG.

This corresponds to the log of the sum of the weights of all sequences with prefix context.

For example, if the WCFG accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • prefix("c") returns \(\log(w_{cat} + w_{car})\)

  • prefix("cat") returns \(\log(w_{cat})\)

  • prefix("d") returns \(-\infty\) since the WCFG does not accept any sequences with prefix "d"
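On the same kind of toy language, the prefix weight is simply the sum of the weights of every string that starts with the context. This is a minimal sketch assuming an explicit dict of strings with hypothetical weights (strings stand in for token lists); the WCFG computes the same quantity with an Earley parser over the prefix grammar instead of enumerating strings.

```python
import math

# Toy weighted language with hypothetical weights.
LANGUAGE = {"cat": 0.6, "car": 0.4}


def prefix(context):
    """Log of the total weight of all strings in LANGUAGE that start with `context`."""
    total = sum(w for s, w in LANGUAGE.items() if s.startswith(context))
    return math.log(total) if total > 0 else float("-inf")


print(prefix("c"))    # log(0.6 + 0.4) = 0.0
print(prefix("cat"))  # log(0.6)
print(prefix("d"))    # -inf
```

Because every string is a prefix of itself, prefix(context) is never smaller than complete(context) when weights are non-negative.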

Parameters:

  • context (list, required): A sequence of tokens in the WCFG's alphabet.

Returns:

  • float: The log prefix weight of context under the WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
async def prefix(self, context):
    """
    Compute the log prefix weight of `context` under the WCFG.

    This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

    For example, if the WCFG accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
    - `prefix("cat")` returns $\\log(w_{cat})$\n
    - `prefix("d")` returns $-\\infty$ since the WCFG does not accept any sequences with prefix "d"

    Args:
        context (list): A sequence of tokens in the WCFG's alphabet.

    Returns:
        (float): The log prefix weight of `context` under the WCFG.
    """
    w = self.model(context)
    return np.log(w) if w > 0 else float("-inf")

logw_next(context) async

Compute the next token log weights given context.

Parameters:

  • context (list, required): A sequence of tokens in the WCFG's alphabet.

Returns:

  • LazyWeights: The log weights for the next tokens and EOS given context.

Source code in genlm/control/potential/built_in/wcfg.py
async def logw_next(self, context):
    """
    Compute the next token log weights given `context`.

    Args:
        context (list): A sequence of tokens in the WCFG's alphabet.

    Returns:
        (LazyWeights): The log weights for the next tokens and EOS given `context`.
    """
    ws = self.model.next_token_weights(self.model.chart(context))
    ws = ws.trim().normalize()

    ws_array = np.array([ws[x] for x in self.vocab_eos])
    mask = ws_array > 0
    log_ws = np.full_like(ws_array, float("-inf"), dtype=np.float64)
    log_ws[mask] = np.log(ws_array[mask])

    return self.make_lazy_weights(log_ws)
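The last two steps of the method above (normalizing the next-token weights, then converting them to logs only where they are positive) can be illustrated without a parser. This sketch assumes a toy weighted language stored as a dict with hypothetical weights and a hypothetical `EOS` sentinel; it does not reproduce the Earley chart or `trim()`, only the normalization and the masked log conversion that avoids calling `np.log(0)`.

```python
import numpy as np

# Toy weighted language (hypothetical weights) over the vocabulary below.
LANGUAGE = {"cat": 0.6, "car": 0.4}
VOCAB = ["a", "c", "r", "t"]
EOS = "<eos>"  # hypothetical end-of-sequence sentinel
VOCAB_EOS = VOCAB + [EOS]  # EOS is the last entry, as in vocab_eos


def next_token_weights(context):
    """Unnormalized weight of each next token (and EOS) given `context`."""
    ws = {x: 0.0 for x in VOCAB_EOS}
    for s, w in LANGUAGE.items():
        if s == context:
            ws[EOS] += w                 # context is itself complete: stop here
        elif s.startswith(context):
            ws[s[len(context)]] += w     # symbol that follows the context
    return ws


def logw_next(context):
    ws = next_token_weights(context)
    total = sum(ws.values())
    if total == 0:
        raise ValueError(f"Context {context!r} has zero weight.")
    ws_array = np.array([ws[x] / total for x in VOCAB_EOS])  # normalize
    # Take logs only where the weight is positive; zero weights become -inf
    # without triggering a divide-by-zero warning from np.log(0).
    mask = ws_array > 0
    log_ws = np.full_like(ws_array, float("-inf"), dtype=np.float64)
    log_ws[mask] = np.log(ws_array[mask])
    return log_ws


print(dict(zip(VOCAB_EOS, logw_next("ca"))))  # 't' -> log(0.6), 'r' -> log(0.4), others -inf
```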

clear_cache()

Clear the internal cache of the parser.

Source code in genlm/control/potential/built_in/wcfg.py
def clear_cache(self):
    """Clear the internal cache of the parser."""
    self.model.clear_cache()

spawn()

Spawn a new WCFG.

Source code in genlm/control/potential/built_in/wcfg.py
def spawn(self):
    """Spawn a new WCFG."""
    return WCFG(self.cfg)

BoolCFG

Bases: Potential

BoolCFG represents a boolean context-free grammar.

Source code in genlm/control/potential/built_in/wcfg.py
class BoolCFG(Potential):
    """BoolCFG represents a boolean context-free grammar."""

    def __init__(self, cfg):
        if cfg.R != Boolean:
            cfg = cfg.map_values(lambda x: Boolean(x > 0), Boolean)
        self.cfg = cfg  # cfg before prefix transform
        self.cfg_eos = _add_eos(cfg, EOS)  # augmented with eos
        self.model = Earley(self.cfg_eos.prefix_grammar)
        super().__init__(vocabulary=list(cfg.V))

    @classmethod
    def from_lark(cls, lark_string, charset="core"):
        """
        Create a BoolCFG instance from a Lark grammar string.

        The output grammar will be defined at the byte-level.

        Args:
            lark_string (str): The Lark grammar string to parse. See Lark documentation for correct syntax.
            charset (str): The character set to use. Defaults to "core".
                See `genlm-grammar` documentation for more details.

        Returns:
            (BoolCFG): An instance of BoolCFG created from the provided Lark grammar.
        """
        byte_cfg = LarkStuff(lark_string).byte_cfg(charset=charset)
        return cls(byte_cfg)

    async def complete(self, context):
        """
        Checks whether the context is accepted by the CFG.

        Args:
            context (list): A sequence of tokens in the CFG's alphabet.

        Returns:
            (float): Log weight for whether `context` is accepted by the CFG.
        """
        w = self.model([*context, EOS])
        return 0 if w.score else float("-inf")

    async def prefix(self, context):
        """
        Checks whether `context` is accepted as a prefix by the CFG, i.e.,
        whether there exists a completion to `context` that is accepted by the CFG.

        Args:
            context (list): A sequence of tokens in the CFG's alphabet.

        Returns:
            (float): Log weight for whether `context` is accepted as a prefix by the CFG.
        """
        if not context:  # FIX: this is a hack to handle the empty string because genlm-grammar doesn't support it
            return 0
        w = self.model(context)
        return 0 if w.score else float("-inf")

    async def logw_next(self, context):
        """
        Compute the next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the CFG's alphabet.

        Returns:
            (LazyWeights): The log weights for the next tokens and EOS given `context`.
        """
        ws = self.model.next_token_weights(self.model.chart(context))
        log_ws = np.array([0 if ws[x].score else float("-inf") for x in self.vocab_eos])
        return self.make_lazy_weights(log_ws)

    async def batch_logw_next(self, contexts):
        """
        Batch version of `logw_next`.

        Args:
            contexts (list): A list of sequences of tokens in the CFG's alphabet.

        Returns:
            (list): A list of log-weights for next token, one per context.
        """
        Ws = []
        for context in contexts:
            ws = self.model.next_token_weights(self.model.chart(context))
            log_ws = np.array(
                [0 if ws[x].score else float("-inf") for x in self.vocab_eos]
            )
            Ws.append(self.make_lazy_weights(log_ws))
        return Ws

    def spawn(self):
        """Spawn a new BoolCFG."""
        return BoolCFG(self.cfg)

    def clear_cache(self):
        """Clear the internal cache of the parser."""
        self.model.clear_cache()

    def __repr__(self):
        return f"BoolCFG(cfg={self.cfg!r})"

    def _repr_html_(self):
        return self.cfg._repr_html_()

from_lark(lark_string, charset='core') classmethod

Create a BoolCFG instance from a Lark grammar string.

The output grammar will be defined at the byte-level.

Parameters:

  • lark_string (str, required): The Lark grammar string to parse. See the Lark documentation for the correct syntax.

  • charset (str, default: 'core'): The character set to use. See the genlm-grammar documentation for more details.

Returns:

  • BoolCFG: An instance of BoolCFG created from the provided Lark grammar.

Source code in genlm/control/potential/built_in/wcfg.py
@classmethod
def from_lark(cls, lark_string, charset="core"):
    """
    Create a BoolCFG instance from a Lark grammar string.

    The output grammar will be defined at the byte-level.

    Args:
        lark_string (str): The Lark grammar string to parse. See Lark documentation for correct syntax.
        charset (str): The character set to use. Defaults to "core".
            See `genlm-grammar` documentation for more details.

    Returns:
        (BoolCFG): An instance of BoolCFG created from the provided Lark grammar.
    """
    byte_cfg = LarkStuff(lark_string).byte_cfg(charset=charset)
    return cls(byte_cfg)

complete(context) async

Checks whether the context is accepted by the CFG.

Parameters:

  • context (list, required): A sequence of tokens in the CFG's alphabet.

Returns:

  • float: Log weight indicating whether context is accepted by the CFG: 0 if it is, -inf otherwise.

Source code in genlm/control/potential/built_in/wcfg.py
async def complete(self, context):
    """
    Checks whether the context is accepted by the CFG.

    Args:
        context (list): A sequence of tokens in the CFG's alphabet.

    Returns:
        (float): Log weight for whether `context` is accepted by the CFG.
    """
    w = self.model([*context, EOS])
    return 0 if w.score else float("-inf")

prefix(context) async

Checks whether context is accepted as a prefix by the CFG, i.e., whether there exists a completion to context that is accepted by the CFG.

Parameters:

  • context (list, required): A sequence of tokens in the CFG's alphabet.

Returns:

  • float: Log weight indicating whether context is accepted as a prefix by the CFG: 0 if it is, -inf otherwise.

Source code in genlm/control/potential/built_in/wcfg.py
async def prefix(self, context):
    """
    Checks whether `context` is accepted as a prefix by the CFG, i.e.,
    whether there exists a completion to `context` that is accepted by the CFG.

    Args:
        context (list): A sequence of tokens in the CFG's alphabet.

    Returns:
        (float): Log weight for whether `context` is accepted as a prefix by the CFG.
    """
    if not context:  # FIX: this is a hack to handle the empty string because genlm-grammar doesn't support it
        return 0
    w = self.model(context)
    return 0 if w.score else float("-inf")

logw_next(context) async

Compute the next token log weights given context.

Parameters:

  • context (list, required): A sequence of tokens in the CFG's alphabet.

Returns:

  • LazyWeights: The log weights for the next tokens and EOS given context.

Source code in genlm/control/potential/built_in/wcfg.py
async def logw_next(self, context):
    """
    Compute the next token log weights given `context`.

    Args:
        context (list): A sequence of tokens in the CFG's alphabet.

    Returns:
        (LazyWeights): The log weights for the next tokens and EOS given `context`.
    """
    ws = self.model.next_token_weights(self.model.chart(context))
    log_ws = np.array([0 if ws[x].score else float("-inf") for x in self.vocab_eos])
    return self.make_lazy_weights(log_ws)
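For a Boolean grammar, the next-token weights are not normalized: every token that keeps the context viable gets log weight 0 (weight 1) and every other token gets -inf. The sketch below illustrates that rule on a hypothetical finite set of accepted strings instead of a parsed CFG, with a hypothetical `EOS` sentinel; EOS is allowed exactly when the context is itself accepted.

```python
import numpy as np

# Toy unweighted language (hypothetical): the set of accepted strings.
ACCEPTED = {"cat", "car", "cart"}
VOCAB = ["a", "c", "r", "t"]
EOS = "<eos>"  # hypothetical end-of-sequence sentinel
VOCAB_EOS = VOCAB + [EOS]


def bool_logw_next(context):
    """0 for every next token (or EOS) that keeps `context` viable, -inf otherwise."""
    allowed = set()
    for s in ACCEPTED:
        if s == context:
            allowed.add(EOS)               # context is accepted: stopping is allowed
        elif s.startswith(context):
            allowed.add(s[len(context)])   # this symbol can continue the context
    return np.array([0.0 if x in allowed else float("-inf") for x in VOCAB_EOS])


print(dict(zip(VOCAB_EOS, bool_logw_next("car"))))  # 't' and EOS -> 0.0, others -inf
```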

batch_logw_next(contexts) async

Batch version of logw_next.

Parameters:

  • contexts (list, required): A list of sequences of tokens in the CFG's alphabet.

Returns:

  • list: A list of next-token log weights, one per context.

Source code in genlm/control/potential/built_in/wcfg.py
async def batch_logw_next(self, contexts):
    """
    Batch version of `logw_next`.

    Args:
        contexts (list): A list of sequences of tokens in the CFG's alphabet.

    Returns:
        (list): A list of log-weights for next token, one per context.
    """
    Ws = []
    for context in contexts:
        ws = self.model.next_token_weights(self.model.chart(context))
        log_ws = np.array(
            [0 if ws[x].score else float("-inf") for x in self.vocab_eos]
        )
        Ws.append(self.make_lazy_weights(log_ws))
    return Ws
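The batched method above is equivalent to calling logw_next on each context in turn and collecting the results in input order. A generic async alternative is to schedule every per-context call with `asyncio.gather`; the sketch below shows that pattern on a hypothetical `ToyPotential` class whose logw_next returns placeholder values. It is an illustration of the batching contract, not the library's default implementation.

```python
import asyncio


class ToyPotential:
    """Hypothetical stand-in for a potential with an async logw_next."""

    async def logw_next(self, context):
        # Placeholder computation (illustration only): one weight per position plus EOS.
        return [0.0] * (len(context) + 1)

    async def batch_logw_next(self, contexts):
        # One logw_next call per context; gather returns results in input order.
        return list(await asyncio.gather(*(self.logw_next(c) for c in contexts)))


results = asyncio.run(ToyPotential().batch_logw_next([[], ["a"], ["a", "b"]]))
print(results)  # [[0.0], [0.0, 0.0], [0.0, 0.0, 0.0]]
```

Since asyncio runs on a single thread, coroutines that never await an I/O operation still execute one after another, so gather adds no parallelism for CPU-bound work such as chart parsing; a plain loop, as in BoolCFG, is just as fast in that case.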

spawn()

Spawn a new BoolCFG.

Source code in genlm/control/potential/built_in/wcfg.py
def spawn(self):
    """Spawn a new BoolCFG."""
    return BoolCFG(self.cfg)

clear_cache()

Clear the internal cache of the parser.

Source code in genlm/control/potential/built_in/wcfg.py
def clear_cache(self):
    """Clear the internal cache of the parser."""
    self.model.clear_cache()

WFSA

Bases: Potential

A weighted finite state automaton (WFSA) potential.

This class wraps a genlm_grammar.WFSA and provides methods for computing the log-weight of a context, the prefix log-weight of a context, and the log-weights of the next token given a context.

Attributes:

  • wfsa (WFSA): The weighted finite state automaton used for potential calculations.

Source code in genlm/control/potential/built_in/wfsa.py
class WFSA(Potential):
    """
    A weighted finite state automaton (WFSA) potential.

    This class wraps a `genlm_grammar.WFSA` and provides methods for computing the log-weight of a context,
    the prefix log-weight of a context, and the log-weights of the next token given a context.

    Attributes:
        wfsa (genlm_grammar.WFSA): The weighted finite state automaton used for potential calculations.
    """

    def __init__(self, wfsa):
        """
        Initializes the WFSA potential.

        Args:
            wfsa (genlm_grammar.WFSA): The weighted finite state automaton.

        Raises:
            ValueError: If the semiring of the provided WFSA is not Float or Log.

        Note:
            The WFSA will be converted to the Log semiring to avoid underflow if the semiring is Float.
        """
        if wfsa.R not in (Float, Log):
            raise ValueError(f"Unsupported semiring: {wfsa.R}")

        if wfsa.R is Float:
            self.wfsa = self._convert_to_log(wfsa)
        else:
            self.wfsa = wfsa

        self.cache = {(): self.wfsa.epsremove.start}
        super().__init__(vocabulary=list(self.wfsa.alphabet))

    @classmethod
    def from_regex(cls, pattern, charset=None, to_bytes=True):
        """
        Create a WFSA from a regex pattern.

        Args:
            pattern (str): The regex pattern to convert into a WFSA.
            charset (set): The character set to use for negative character classes.
                Defaults to characters in string.printable.
            to_bytes (bool): Whether to convert the WFSA transitions to bytes.
                Defaults to True. When set to False, the WFSA transitions will be strings.

        Returns:
            (WFSA): An instance of the WFSA class.

        Note:
            The transition weights are automatically normalized to form a probability distribution.
            For each state, the weights of all outgoing transitions (including final state transitions)
            sum to 1.0. This means if a state has n possible transitions, each transition will have
            weight 1/n. To create a WFSA from a regex with non-probabilistic transitions, use `BoolFSA`.
        """
        charset = charset or set(string.printable)
        wfsa = interegular_to_wfsa(pattern, charset=charset)
        if to_bytes:
            wfsa = wfsa.to_bytes()
        return cls(wfsa=wfsa)

    @staticmethod
    def _convert_to_log(wfsa):
        """Convert a WFSA from the Float semiring to the Log semiring."""
        assert wfsa.R is Float
        assert isinstance(wfsa, BaseWFSA)
        new = BaseWFSA(Log)

        for i, w in wfsa.I:
            new.add_I(i, Log(np.log(w)))

        for i, w in wfsa.F:
            new.add_F(i, Log(np.log(w)))

        for i, a, j, w in wfsa.arcs():
            new.add_arc(i, a, j, Log(np.log(w)))

        return new

    def _consume(self, bs):
        # XXX implement cache eviction
        bs = tuple(bs)

        try:
            return self.cache[bs]
        except KeyError:
            pass

        wfsa = self.wfsa.epsremove
        curr = wfsa.R.chart()
        prev = self._consume(bs[:-1])
        for i in prev:
            for j, w in wfsa.arcs(i, bs[-1]):
                curr[j] += prev[i] * w

        self.cache[bs] = curr

        return curr

    async def complete(self, context):
        """
        Computes the log weight of the context under the weighted language represented by the WFSA.

        For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WFSA\n
        - `complete("cat")` returns $\\log(w_{cat})$\n
        - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WFSA

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (float): Log weight of context under the WFSA.
        """
        # TODO: optimize to use _consume cache
        return self.wfsa(context).score

    def _prefix(self, context):
        curr = self._consume(context)

        if not curr:
            return float("-inf"), curr

        bkwd = self.wfsa.epsremove.backward
        log_ctx_w = logsumexp([(curr[i] * bkwd[i]).score for i in curr])

        if np.isnan(log_ctx_w):
            return float("-inf"), curr

        return log_ctx_w, curr

    async def prefix(self, context):
        """
        Computes the prefix log weight of `context` under the WFSA.

        This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

        For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
        - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
        - `prefix("cat")` returns $\\log(w_{cat})$\n
        - `prefix("d")` returns $-\\infty$ since the WFSA does not accept any sequences with prefix "d"

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (float): Log weight of `context` as a prefix under the WFSA.
        """
        return self._prefix(context)[0]

    async def logw_next(self, context):
        """Returns next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (LazyWeights): Log-weights for next token and EOS.
        """
        log_ctx_w, curr = self._prefix(context)

        if log_ctx_w == float("-inf"):
            raise ValueError(f"Context {context!r} has zero weight.")

        bkwd = self.wfsa.epsremove.backward

        ws = self.wfsa.R.chart()
        for i in curr:
            for b, j, w in self.wfsa.epsremove.arcs(i=i):
                ws[b] += curr[i] * w * bkwd[j]

        ws[self.eos] = self.wfsa.R.zero
        for j, w in self.wfsa.epsremove.F:
            ws[self.eos] += curr[j] * w

        log_ws = np.array([ws[b].score for b in self.vocab_eos]) - log_ctx_w

        return self.make_lazy_weights(log_ws)

    def _repr_svg_(self):
        return self.wfsa._repr_svg_()

    def __repr__(self):
        return f"WFSA(wfsa={self.wfsa!r})"

    def spawn(self):
        cls = type(self)
        return cls(wfsa=self.wfsa)

    def clear_cache(self):
        self.cache = {(): self.wfsa.epsremove.start}

__init__(wfsa)

Initializes the WFSA potential.

Parameters:

  • wfsa (WFSA, required): The weighted finite state automaton.

Raises:

  • ValueError: If the semiring of the provided WFSA is not Float or Log.

Note

The WFSA will be converted to the Log semiring to avoid underflow if the semiring is Float.
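Why the conversion matters can be seen with plain Python floats. In the Float semiring the weight of a path is the product of its transition weights, which underflows to 0.0 for long paths; in the Log semiring the same product is a sum of log weights, which stays representable. The transition count and weight below are hypothetical.

```python
import math

# 1,000 transitions, each with (hypothetical) weight 1e-5.
weights = [1e-5] * 1000

# Float semiring: multiply weights directly.
float_product = 1.0
for w in weights:
    float_product *= w

# Log semiring: the same product becomes a sum of log weights.
log_product = sum(math.log(w) for w in weights)

print(float_product)  # 0.0: the product underflowed
print(log_product)    # about -11512.9, still representable
```

In the Log semiring, semiring addition (summing weights over alternative paths) correspondingly becomes log-sum-exp rather than ordinary addition.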

Source code in genlm/control/potential/built_in/wfsa.py
def __init__(self, wfsa):
    """
    Initializes the WFSA potential.

    Args:
        wfsa (genlm_grammar.WFSA): The weighted finite state automaton.

    Raises:
        ValueError: If the semiring of the provided WFSA is not Float or Log.

    Note:
        The WFSA will be converted to the Log semiring to avoid underflow if the semiring is Float.
    """
    if wfsa.R not in (Float, Log):
        raise ValueError(f"Unsupported semiring: {wfsa.R}")

    if wfsa.R is Float:
        self.wfsa = self._convert_to_log(wfsa)
    else:
        self.wfsa = wfsa

    self.cache = {(): self.wfsa.epsremove.start}
    super().__init__(vocabulary=list(self.wfsa.alphabet))

from_regex(pattern, charset=None, to_bytes=True) classmethod

Create a WFSA from a regex pattern.

Parameters:

  • pattern (str, required): The regex pattern to convert into a WFSA.

  • charset (set, default: None): The character set to use for negative character classes. Defaults to the characters in string.printable.

  • to_bytes (bool, default: True): Whether to convert the WFSA transitions to bytes. When set to False, the WFSA transitions will be strings.

Returns:

  • WFSA: An instance of the WFSA class.

Note

The transition weights are automatically normalized to form a probability distribution. For each state, the weights of all outgoing transitions (including final state transitions) sum to 1.0. This means if a state has n possible transitions, each transition will have weight 1/n. To create a WFSA from a regex with non-probabilistic transitions, use BoolFSA.
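The normalization rule in the note can be checked by hand on a small automaton. The sketch below assumes a hypothetical hand-built DFA for the regex "ca[tr]" (not the automaton that interegular actually produces) and gives every outgoing transition of a state, counting finality as one transition, weight 1/n; exact fractions keep the arithmetic readable.

```python
from fractions import Fraction

# Hypothetical DFA for the regex "ca[tr]": state -> {symbol: next_state}.
ARCS = {0: {"c": 1}, 1: {"a": 2}, 2: {"t": 3, "r": 3}, 3: {}}
FINAL = {3}


def out_degree(state):
    """Number of outgoing 'transitions', counting finality as one of them."""
    return len(ARCS[state]) + (1 if state in FINAL else 0)


def weight(string, start=0):
    """Weight of `string` when every state splits its mass uniformly (1/n each)."""
    state, w = start, Fraction(1)
    for symbol in string:
        if symbol not in ARCS[state]:
            return Fraction(0)
        w *= Fraction(1, out_degree(state))
        state = ARCS[state][symbol]
    if state not in FINAL:
        return Fraction(0)
    return w * Fraction(1, out_degree(state))  # finality is one of the n choices


print(weight("cat"), weight("car"), weight("ca"))  # 1/2 1/2 0
```

The two matched strings split the total mass of 1 evenly because state 2 has two outgoing transitions. For patterns with loops the weights decay geometrically: in a single-state automaton for "a*" (one self-loop plus finality), the string of k "a"s gets weight (1/2)^(k+1).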

Source code in genlm/control/potential/built_in/wfsa.py
@classmethod
def from_regex(cls, pattern, charset=None, to_bytes=True):
    """
    Create a WFSA from a regex pattern.

    Args:
        pattern (str): The regex pattern to convert into a WFSA.
        charset (set): The character set to use for negative character classes.
            Defaults to characters in string.printable.
        to_bytes (bool): Whether to convert the WFSA transitions to bytes.
            Defaults to True. When set to False, the WFSA transitions will be strings.

    Returns:
        (WFSA): An instance of the WFSA class.

    Note:
        The transition weights are automatically normalized to form a probability distribution.
        For each state, the weights of all outgoing transitions (including final state transitions)
        sum to 1.0. This means if a state has n possible transitions, each transition will have
        weight 1/n. To create a WFSA from a regex with non-probabilistic transitions, use `BoolFSA`.
    """
    charset = charset or set(string.printable)
    wfsa = interegular_to_wfsa(pattern, charset=charset)
    if to_bytes:
        wfsa = wfsa.to_bytes()
    return cls(wfsa=wfsa)

complete(context) async

Computes the log weight of the context under the weighted language represented by the WFSA.

For example, if the WFSA accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • complete("c") returns \(-\infty\) since this sequence is not accepted by the WFSA

  • complete("cat") returns \(\log(w_{cat})\)

  • complete("d") returns \(-\infty\) since this sequence is not accepted by the WFSA
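The complete weight of a WFSA is the forward weight of each state reached after reading the context, combined with that state's final weight. The sketch below does this in the Log semiring on a hypothetical hand-built automaton accepting "cat" (weight 0.6) and "car" (weight 0.4): semiring multiplication becomes addition of logs and semiring addition becomes log-sum-exp. The source's `_consume` helper performs a similar forward step (with a per-prefix cache); per its TODO, `complete` currently evaluates the automaton directly instead.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")

# Hypothetical WFSA in the Log semiring (weights stored as log probabilities).
START = {0: 0.0}                      # state -> log initial weight
ARCS = {                              # (state, symbol) -> [(next_state, log weight)]
    (0, "c"): [(1, 0.0)],
    (1, "a"): [(2, 0.0)],
    (2, "t"): [(3, math.log(0.6))],
    (2, "r"): [(3, math.log(0.4))],
}
FINAL = {3: 0.0}                      # state -> log final weight


def log_add(a, b):
    """Log-semiring addition: log(exp(a) + exp(b)), safe when either is -inf."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def forward(context):
    """Log weight of reaching each state after reading `context`."""
    curr = dict(START)
    for symbol in context:
        nxt = defaultdict(lambda: NEG_INF)
        for i, wi in curr.items():
            for j, w in ARCS.get((i, symbol), []):
                nxt[j] = log_add(nxt[j], wi + w)
        curr = dict(nxt)
    return curr


def complete(context):
    """Combine forward weights with final weights (log-sum-exp over final states)."""
    total = NEG_INF
    for i, wi in forward(context).items():
        if i in FINAL:
            total = log_add(total, wi + FINAL[i])
    return total


print(complete("cat"), complete("c"), complete("d"))  # log(0.6), -inf, -inf
```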

Parameters:

  • context (list, required): A sequence of tokens in the WFSA's alphabet.

Returns:

  • float: Log weight of context under the WFSA.

Source code in genlm/control/potential/built_in/wfsa.py
async def complete(self, context):
    """
    Computes the log weight of the context under the weighted language represented by the WFSA.

    For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `complete("c")` returns $-\\infty$ since this sequence is not accepted by the WFSA\n
    - `complete("cat")` returns $\\log(w_{cat})$\n
    - `complete("d")` returns $-\\infty$ since this sequence is not accepted by the WFSA

    Args:
        context (list): A sequence of tokens in the WFSA's alphabet.

    Returns:
        (float): Log weight of context under the WFSA.
    """
    # TODO: optimize to use _consume cache
    return self.wfsa(context).score

prefix(context) async

Computes the prefix log weight of context under the WFSA.

This corresponds to the log of the sum of the weights of all sequences with prefix context.

For example, if the WFSA accepts "cat" and "car" with weights \(w_{cat}\) and \(w_{car}\):

  • prefix("c") returns \(\log(w_{cat} + w_{car})\)

  • prefix("cat") returns \(\log(w_{cat})\)

  • prefix("d") returns \(-\infty\) since the WFSA does not accept any sequences with prefix "d"
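The prefix weight combines two quantities, as in `_prefix` in the source above: the forward weight of each state reached after reading the context, and the backward weight of that state (the total weight of all ways to finish from it). The sketch below works in ordinary probability space on a hypothetical acyclic automaton accepting "cat" (0.6) and "car" (0.4); the library does the same sum in log space with logsumexp.

```python
import math

# Hypothetical acyclic WFSA with probability weights.
START = {0: 1.0}
ARCS = {  # state -> [(symbol, next_state, weight)]
    0: [("c", 1, 1.0)],
    1: [("a", 2, 1.0)],
    2: [("t", 3, 0.6), ("r", 3, 0.4)],
    3: [],
}
FINAL = {3: 1.0}


def forward(context):
    """Weight of reaching each state after reading `context`."""
    curr = dict(START)
    for symbol in context:
        nxt = {}
        for i, wi in curr.items():
            for b, j, w in ARCS[i]:
                if b == symbol:
                    nxt[j] = nxt.get(j, 0.0) + wi * w
        curr = nxt
    return curr


def backward(state):
    """Total weight of all paths from `state` to acceptance (assumes no cycles)."""
    return FINAL.get(state, 0.0) + sum(w * backward(j) for _, j, w in ARCS[state])


def prefix(context):
    """Log of the sum over states of forward weight times backward weight."""
    total = sum(wi * backward(i) for i, wi in forward(context).items())
    return math.log(total) if total > 0 else float("-inf")


print(prefix("c"), prefix("ca"), prefix("cat"), prefix("d"))  # 0.0 0.0 log(0.6) -inf
```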

Parameters:

  • context (list, required): A sequence of tokens in the WFSA's alphabet.

Returns:

  • float: Log weight of context as a prefix under the WFSA.

Source code in genlm/control/potential/built_in/wfsa.py
async def prefix(self, context):
    """
    Computes the prefix log weight of `context` under the WFSA.

    This corresponds to the log of the sum of the weights of all sequences with prefix `context`.

    For example, if the WFSA accepts "cat" and "car" with weights $w_{cat}$ and $w_{car}$:\n
    - `prefix("c")` returns $\\log(w_{cat} + w_{car})$\n
    - `prefix("cat")` returns $\\log(w_{cat})$\n
    - `prefix("d")` returns $-\\infty$ since the WFSA does not accept any sequences with prefix "d"

    Args:
        context (list): A sequence of tokens in the WFSA's alphabet.

    Returns:
        (float): Log weight of `context` as a prefix under the WFSA.
    """
    return self._prefix(context)[0]

logw_next(context) async

Returns next token log weights given context.

Parameters:

  • context (list, required): A sequence of tokens in the WFSA's alphabet.

Returns:

  • LazyWeights: Log weights for the next token and EOS.

Source code in genlm/control/potential/built_in/wfsa.py
async def logw_next(self, context):
    """Returns next token log weights given `context`.

    Args:
        context (list): A sequence of tokens in the WFSA's alphabet.

    Returns:
        (LazyWeights): Log-weights for next token and EOS.
    """
    log_ctx_w, curr = self._prefix(context)

    if log_ctx_w == float("-inf"):
        raise ValueError(f"Context {context!r} has zero weight.")

    bkwd = self.wfsa.epsremove.backward

    ws = self.wfsa.R.chart()
    for i in curr:
        for b, j, w in self.wfsa.epsremove.arcs(i=i):
            ws[b] += curr[i] * w * bkwd[j]

    ws[self.eos] = self.wfsa.R.zero
    for j, w in self.wfsa.epsremove.F:
        ws[self.eos] += curr[j] * w

    log_ws = np.array([ws[b].score for b in self.vocab_eos]) - log_ctx_w

    return self.make_lazy_weights(log_ws)
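The method above scores each token b by summing, over current states i and arcs i -b-> j, the product of the forward weight of i, the arc weight, and the backward weight of j; EOS is scored by the forward weight times the final weight. Dividing by the prefix weight makes the results sum to 1. The sketch below repeats that computation in probability space on a hypothetical acyclic automaton accepting "cat" (0.6) and "car" (0.4), with a hypothetical vocabulary whose last entry is EOS.

```python
import math

# Hypothetical acyclic WFSA with probability weights.
START = {0: 1.0}
ARCS = {  # state -> [(symbol, next_state, weight)]
    0: [("c", 1, 1.0)],
    1: [("a", 2, 1.0)],
    2: [("t", 3, 0.6), ("r", 3, 0.4)],
    3: [],
}
FINAL = {3: 1.0}
VOCAB_EOS = ["a", "c", "r", "t", "<eos>"]  # hypothetical vocabulary; EOS is last


def forward(context):
    curr = dict(START)
    for symbol in context:
        nxt = {}
        for i, wi in curr.items():
            for b, j, w in ARCS[i]:
                if b == symbol:
                    nxt[j] = nxt.get(j, 0.0) + wi * w
        curr = nxt
    return curr


def backward(state):
    """Total weight of all paths from `state` to acceptance (assumes no cycles)."""
    return FINAL.get(state, 0.0) + sum(w * backward(j) for _, j, w in ARCS[state])


def logw_next(context):
    curr = forward(context)
    ctx_w = sum(wi * backward(i) for i, wi in curr.items())  # prefix weight
    if ctx_w == 0:
        raise ValueError(f"Context {context!r} has zero weight.")
    ws = {x: 0.0 for x in VOCAB_EOS}
    for i, wi in curr.items():
        for b, j, w in ARCS[i]:
            ws[b] += wi * w * backward(j)        # extend by b, then finish from j
        ws["<eos>"] += wi * FINAL.get(i, 0.0)    # stop in state i
    return [math.log(ws[x] / ctx_w) if ws[x] > 0 else float("-inf") for x in VOCAB_EOS]


print(dict(zip(VOCAB_EOS, logw_next("ca"))))  # 't' -> log(0.6), 'r' -> log(0.4), others -inf
```

The normalization is exact because the numerators, summed over all tokens and EOS, expand to the same forward-times-backward sum that defines the prefix weight.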

BoolFSA

Bases: WFSA

Boolean FSA potential.

Source code in genlm/control/potential/built_in/wfsa.py
class BoolFSA(WFSA):
    """Boolean FSA potential."""

    async def prefix(self, context):
        """
        Computes whether the context is accepted as a prefix by the FSA.

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (float): `0` if the context is accepted as a prefix, `-inf` otherwise.
        """
        prefix_w = await super().prefix(context)
        if prefix_w > float("-inf"):
            return 0
        return float("-inf")

    async def complete(self, context):
        """
        Computes whether the context is accepted by the FSA.

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (float): `0` if the context is accepted, `-inf` otherwise.
        """
        complete_w = await super().complete(context)
        if complete_w > float("-inf"):
            return 0
        return float("-inf")

    async def logw_next(self, context):
        """
        Returns next token log weights given `context`.

        Args:
            context (list): A sequence of tokens in the WFSA's alphabet.

        Returns:
            (LazyWeights): Boolean log-weights for next token.
        """
        logw_next = await super().logw_next(context)
        return logw_next.spawn(
            new_weights=np.where(
                logw_next.weights > float("-inf"), 0, logw_next.weights
            )
        )

    async def batch_logw_next(self, contexts):
        """
        Returns next token log weights for a batch of contexts.

        Args:
            contexts (list): The list of contexts.

        Returns:
            (list): List of log-weights for next token, one per context.
        """
        logw_nexts = await super().batch_logw_next(contexts)
        return [
            logw_next.spawn(
                new_weights=np.where(
                    logw_next.weights > float("-inf"), 0, logw_next.weights
                )
            )
            for logw_next in logw_nexts
        ]

    def __repr__(self):
        return f"BoolFSA(wfsa={self.wfsa!r})"

prefix(context) async

Computes whether the context is accepted as a prefix by the FSA.

Parameters:

  • context (list, required): A sequence of tokens in the WFSA's alphabet.

Returns:

  • float: 0 if the context is accepted as a prefix, -inf otherwise.

Source code in genlm/control/potential/built_in/wfsa.py
async def prefix(self, context):
    """
    Computes whether the context is accepted as a prefix by the FSA.

    Args:
        context (list): A sequence of tokens in the WFSA's alphabet.

    Returns:
        (float): `0` if the context is accepted as a prefix, `-inf` otherwise.
    """
    prefix_w = await super().prefix(context)
    if prefix_w > float("-inf"):
        return 0
    return float("-inf")

complete(context) async

Computes whether the context is accepted by the FSA.

Parameters:

  • context (list, required): A sequence of tokens in the WFSA's alphabet.

Returns:

  • float: 0 if the context is accepted, -inf otherwise.

Source code in genlm/control/potential/built_in/wfsa.py
async def complete(self, context):
    """
    Computes whether the context is accepted by the FSA.

    Args:
        context (list): A sequence of tokens in the WFSA's alphabet.

    Returns:
        (float): `0` if the context is accepted, `-inf` otherwise.
    """
    complete_w = await super().complete(context)
    if complete_w > float("-inf"):
        return 0
    return float("-inf")

logw_next(context) async

Returns next token log weights given context.

Parameters:

  • context (list, required): A sequence of tokens in the WFSA's alphabet.

Returns:

  • LazyWeights: Boolean log weights for the next token.

batch_logw_next(contexts) async

Returns boolean next-token log weights for a batch of contexts, applying the same 0/-inf conversion as logw_next to each context.

Parameters:

Name      Type  Description            Default
contexts  list  The list of contexts.  required

Returns:

Type  Description
list  List of boolean log-weights (LazyWeights) for the next token, one per context.

CanonicalTokenization

Bases: Potential

A custom potential that enforces canonical BPE tokenization.

This potential checks that a token sequence follows the canonical tokenization rules by querying a FastCanonicalityFilterBPE for each pair of adjacent tokens: canonical sequences receive log weight 0.0, non-canonical ones -inf.

Source code in genlm/control/potential/built_in/canonical.py
class CanonicalTokenization(Potential):
    """
    A custom potential that enforces canonical BPE tokenization.

    This potential ensures that tokens follow the canonical tokenization rules
    by using the FastCanonicalityFilterBPE under the hood.
    """

    def __init__(self, canonicality_filter):
        """
        Initialize the Canonical Potential

        Args:
            canonicality_filter (FastCanonicalityFilterBPE): An initialized FastCanonicalityFilterBPE instance.
        """
        # Store the pre-initialized filter and tokenizer
        self.canonicality_filter = canonicality_filter

        # IMPORTANT: In the base Potential class, EOS will be added to vocab automatically
        # So we should NOT add it ourselves to the vocabulary we pass to super().__init__
        # Use Token objects directly as vocabulary to maintain token_id information
        vocabulary = self.canonicality_filter._decode
        super().__init__(vocabulary)

    @classmethod
    def from_llm(cls, llm):
        """
        Factory method to create CanonicalTokenization from a PromptedLLM instance.

        Args:
            llm (PromptedLLM): An instance of PromptedLLM containing the model and tokenizer.

        Returns:
            (CanonicalTokenization): An initialized CanonicalTokenization instance.
        """
        if not isinstance(llm, PromptedLLM):
            raise TypeError(
                f"Expected llm to be an instance of PromptedLLM, got {type(llm)}"
            )

        # Extract necessary components from llm
        tokenizer = llm.model.tokenizer
        eos_token_ids = llm.token_maps.eos_idxs
        model_name = tokenizer.name_or_path

        # Create the filter using its factory method
        canonicality_filter = FastCanonicalityFilterBPE.from_tokenizer(
            tokenizer, eos_token_ids
        )

        # Set overrides on the filter
        canonicality_filter.set_overrides(model_name)

        # Call __init__ with the created filter and tokenizer
        return cls(canonicality_filter)

    async def complete(self, context):
        """
        Assess if a complete sequence follows canonical tokenization.

        Args:
            context (list): Sequence of tokens

        Returns:
            (float): 0.0 if canonical, float('-inf') otherwise
        """
        # Empty sequences are considered canonical
        if not context:
            return 0.0

        # Check if the sequence is canonical
        is_canonical = self._check_canonicality(context)
        return 0.0 if is_canonical else float("-inf")

    async def prefix(self, context):
        """
        Assess if a prefix sequence could potentially extend to a canonical sequence.
        For canonicality, this is the same as complete.

        Args:
            context (list): Sequence of tokens

        Returns:
            (float): 0.0 if potentially canonical, float('-inf') otherwise
        """
        return await self.complete(context)

    async def logw_next(self, context):
        """
        Compute weights for each possible next token given the context.

        Args:
            context (list): Sequence of tokens

        Returns:
            (LazyWeights): Weights for each token in the vocabulary and EOS
        """
        # Get the prefix weight (to check if context itself is canonical)
        ctx_log_w = await self.prefix(context)

        if ctx_log_w == float("-inf"):
            raise ValueError("Context is non-canonical")
        else:
            if context:
                t = (None, context[-1])
                filter_mask = self.canonicality_filter(t)
            else:
                filter_mask = np.ones(len(self.canonicality_filter._decode), dtype=bool)

            # Create log weights directly instead of using np.log(filter_mask)
            # This is more efficient, avoids torch (with torch can't combine with other potentials!)
            logws_no_eos = np.where(filter_mask, 0.0, float("-inf")).astype(np.float32)

            # append eos to the logws, always allow eos.
            # NOTE: concat is because ._decode does not include eos while .vocab_eos does
            logws = np.concatenate([logws_no_eos, np.array([0.0], dtype=np.float32)])

        return self.make_lazy_weights(logws)

    def _check_canonicality(self, context):
        """
        Check if a sequence follows canonical tokenization.

        Args:
            context (list): Sequence of tokens

        Returns:
            (bool): True if the sequence is canonical, False otherwise
        """
        # If we're checking a single token, it's always canonical
        if len(context) <= 1:
            return True

        # Check all adjacent token pairs for canonicality
        for i in range(1, len(context)):
            prev_token = context[i - 1]
            current_token = context[i]

            # Format expected by the filter: (None, previous_token)
            t = (None, prev_token)
            mask = self.canonicality_filter(t)
            # print("percent of mask: ", np.sum(mask)*100 / len(mask))

            # Find token_id in the canonicality filter's vocabulary
            current_token_bytes = Token.as_bytes(current_token)

            token_id = self.canonicality_filter._encode[current_token_bytes]
            if not mask[token_id]:
                return False

        return True

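The pairwise check in `_check_canonicality` can be illustrated without a real tokenizer. The sketch below is a hypothetical toy: a three-token vocabulary in which BPE would merge `a` + `b` into `ab`, and a stand-in `toy_filter` function playing the role of FastCanonicalityFilterBPE (its signature is an assumption and differs from the real filter, which is called with a `(None, previous_token)` tuple). It also shows why `prefix` can simply reuse `complete`: once an adjacent pair is rejected, every extension of the sequence is rejected too.

```python
import numpy as np

# Hypothetical toy vocabulary (index = token id); not a real BPE vocabulary.
DECODE = [b"a", b"b", b"ab"]
ENCODE = {tok: i for i, tok in enumerate(DECODE)}


def toy_filter(prev_token):
    """Hypothetical stand-in for FastCanonicalityFilterBPE.

    Returns a boolean mask over DECODE marking which tokens may follow prev_token.
    """
    mask = np.ones(len(DECODE), dtype=bool)
    if prev_token == b"a":
        # BPE would merge b"a" + b"b" into b"ab", so b"b" may not follow b"a".
        mask[ENCODE[b"b"]] = False
    return mask


def check_canonicality(context):
    """True if every adjacent (previous, current) pair passes the filter."""
    for prev_token, current_token in zip(context, context[1:]):
        if not toy_filter(prev_token)[ENCODE[current_token]]:
            return False
    return True


def log_weight(context):
    # Same convention as complete(): 0.0 if canonical, -inf otherwise.
    return 0.0 if check_canonicality(context) else float("-inf")


print(log_weight([b"ab"]))              # 0.0
print(log_weight([b"a", b"b"]))         # -inf (should have been b"ab")
print(log_weight([b"a", b"b", b"ab"]))  # -inf (extension does not help)
```

In the real class the mask comes from a filter built from the tokenizer rather than a hard-coded rule.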
__init__(canonicality_filter)

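Initialize the CanonicalTokenization potential from a pre-built canonicality filter. The filter's decoded token list becomes the vocabulary; the Potential base class appends EOS automatically.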

Parameters:

Name                 Type                       Description                                         Default
canonicality_filter  FastCanonicalityFilterBPE  An initialized FastCanonicalityFilterBPE instance.  required

from_llm(llm) classmethod

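Factory method to create CanonicalTokenization from a PromptedLLM instance. It builds a FastCanonicalityFilterBPE from the model's tokenizer and EOS token IDs, applies the model-specific overrides via set_overrides, and raises TypeError if llm is not a PromptedLLM.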

Parameters:

Name  Type         Description                                                     Default
llm   PromptedLLM  An instance of PromptedLLM containing the model and tokenizer.  required

Returns:

Type                   Description
CanonicalTokenization  An initialized CanonicalTokenization instance.

complete(context) async

Assess whether a complete sequence follows canonical tokenization. The empty sequence is always considered canonical.

Parameters:

Name     Type  Description         Default
context  list  Sequence of tokens  required

Returns:

Type   Description
float  0.0 if canonical, float('-inf') otherwise

prefix(context) async

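Assess whether a prefix could extend to a canonical sequence. Because canonicality is checked over adjacent token pairs, a prefix containing a rejected pair stays non-canonical under every extension, while a canonical prefix is itself a canonical complete sequence; the prefix weight therefore equals the complete weight.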

Parameters:

Name     Type  Description         Default
context  list  Sequence of tokens  required

Returns:

Type   Description
float  0.0 if potentially canonical, float('-inf') otherwise

logw_next(context) async

Compute log weights for each possible next token given the context. Raises ValueError if the context itself is non-canonical. EOS is always allowed.

Parameters:

Name     Type  Description         Default
context  list  Sequence of tokens  required

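Returns:

Type         Description
LazyWeights  Log weights over the vocabulary and EOS: 0.0 for tokens the canonicality filter allows after the last context token (every token if the context is empty), -inf otherwise; EOS is always 0.0.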

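The final step of logw_next, turning the filter's boolean mask into log weights and appending an always-allowed EOS entry, can be reproduced with NumPy alone. In this sketch the masks are made up for a hypothetical three-token vocabulary and `mask_to_logws` is a hypothetical helper name; in the class, the mask comes from `self.canonicality_filter` (or is all-True for an empty context) and the result is wrapped with `make_lazy_weights`.

```python
import numpy as np


def mask_to_logws(filter_mask):
    """Convert a boolean mask over the vocabulary (without EOS) into log weights over vocab + EOS."""
    # Allowed tokens get log weight 0.0, disallowed tokens get -inf (float32, as in logw_next).
    logws_no_eos = np.where(filter_mask, 0.0, float("-inf")).astype(np.float32)
    # EOS is always allowed and comes last in vocab_eos, so append a 0.0 entry.
    return np.concatenate([logws_no_eos, np.array([0.0], dtype=np.float32)])


# Hypothetical mask for a 3-token vocabulary: token 1 may not follow the last context token.
logws = mask_to_logws(np.array([True, False, True]))

# For an empty context, logw_next uses an all-True mask, so every token is allowed.
empty_logws = mask_to_logws(np.ones(3, dtype=bool))

print(logws.tolist())        # [0.0, -inf, 0.0, 0.0]
print(empty_logws.tolist())  # [0.0, 0.0, 0.0, 0.0]
```

Building the 0.0/-inf array with `np.where` keeps everything in NumPy, which is what the source comments give as the reason for not going through torch.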