Potentials
Potentials are the core object in genlm-control. A potential encodes constraints or preferences by assigning non-negative weights to sequences of tokens.
Potentials guide text generation by:
- Acting as components of samplers, which propose new tokens at each step of the generation process.
- Serving as critics, which reweight candidate sequences at each step based on whether they satisfy the constraint encoded by the potential.
Key concepts
Vocabulary
Each potential has a vocabulary which defines the set of tokens it operates on. Most built-in potentials operate on vocabularies whose tokens are bytes or int objects (the latter often representing individual bytes).
Weight assignment
Potentials assign weights to sequences of tokens from their vocabulary. These weights are always non-negative real numbers, though they are computed in log space for numerical stability.
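As a toy illustration of this log-space convention (plain Python, not library code): a weight of 1.0 corresponds to a log-weight of 0.0, a weight of 0.0 to negative infinity, and multiplying weights corresponds to adding log-weights.

```python
import math

# Toy illustration of the log-space convention (not library code):
# weight 1.0 -> log-weight 0.0; weight 0.0 -> -inf.
log_allowed = math.log(1.0)       # 0.0 in log space
log_forbidden = float('-inf')     # log of a zero weight

# Multiplying weights corresponds to adding log-weights:
combined = log_allowed + math.log(0.5)  # log(1.0 * 0.5)
```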
A potential defines two core weighting functions:
- `complete` - Assigns weights to sequences that are considered "finished" or "complete". For example, a potential enforcing grammatical correctness would assign positive weights to grammatically valid sentences and zero weights (negative infinity in log space) to invalid ones.
- `prefix` - Assigns weights to partial sequences that could potentially be extended into valid complete sequences. For example, a potential enforcing grammatical correctness would assign positive weights to prefixes of grammatically valid sequences.

Given a `complete` method, there are many possible `prefix` methods, providing as much or as little information as desired. The key requirement is that if a prefix has zero weight, then all of its extensions and completions must also have zero weight; in other words, `prefix` cannot rule out sequences that could later become valid.
The relationship between complete and prefix weights is formalized in the Formalization section.
Next-token weights
Potentials also implement a logw_next method, which computes weights for each possible next token in the potential's vocabulary (and a reserved end-of-sequence token) given a context sequence. These weights are crucial for controlled text generation as they can be used to guide the selection of the next token at each step.
The logw_next method is implemented by default in terms of the complete and prefix methods. Potentials will often override this method to provide a more efficient implementation. However, logw_next must satisfy a contract with complete/prefix, specified in Formalization.
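The default can be sketched in plain Python as follows. This is a hedged standalone sketch, not the library's actual implementation; `default_logw_next`, the toy `prefix_fn`/`complete_fn`, and the `"<eos>"` marker are illustrative assumptions.

```python
import math

# Hedged sketch (plain Python, not the library's implementation) of how
# next-token log-weights can be derived from prefix and complete:
#   logw_next(x | ctx)   = prefix(ctx + [x]) - prefix(ctx)   for x in vocab
#   logw_next(eos | ctx) = complete(ctx)     - prefix(ctx)
def default_logw_next(prefix_fn, complete_fn, vocab, ctx):
    base = prefix_fn(ctx)
    weights = {x: prefix_fn(ctx + [x]) - base for x in vocab}
    weights["<eos>"] = complete_fn(ctx) - base
    return weights

# Toy potential: only sequences of exactly length 2 are complete.
prefix_fn = lambda ctx: 0.0 if len(ctx) <= 2 else -math.inf
complete_fn = lambda ctx: 0.0 if len(ctx) == 2 else -math.inf

w = default_logw_next(prefix_fn, complete_fn, ["a", "b"], ["a"])
# Extending ["a"] keeps a valid prefix (log-weight 0.0); stopping here with
# eos is disallowed (-inf) because the sequence is not yet length 2.
```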
Batch methods
For improved performance with large batches of inputs, potentials support batch operations:
- `batch_complete(contexts)`
- `batch_prefix(contexts)`
- `batch_logw_next(contexts)`
- `batch_score(contexts)`
By default, these methods simply call the corresponding non-batch method for all inputs, but potentials can override them to provide more efficient implementations. They can be used in conjunction with auto batching for improved performance during generation.
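The default fan-out behavior can be sketched in standalone Python; the names `prefix` and `batch_prefix` mirror the API described above, but this is a toy with a placeholder scoring function, not library code.

```python
import asyncio

# Hedged sketch of default batching: call the non-batch method for each
# input concurrently with asyncio.gather. Standalone toy, not library code.
async def prefix(context):
    return float(len(context))  # placeholder log-weight

async def batch_prefix(contexts):
    return list(await asyncio.gather(*(prefix(ctx) for ctx in contexts)))

results = asyncio.run(batch_prefix([[1], [1, 2], []]))
```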
Built-in potentials
genlm-control comes with a number of built-in potentials that can be used in controlled text generation.
Language models
PromptedLLM represents a language model conditioned on a fixed prompt prefix.
```python
# Load GPT-2 with temperature 0.5
llm = PromptedLLM.from_name("gpt2", temperature=0.5)

# Set a prompt prefix that all generations will be conditioned on
llm.set_prompt_from_str("Montreal is")
```
PromptedLLMs have a vocabulary of bytes tokens, obtained from the language model's tokenizer.
Finite-state automata
genlm-control provides two FSA implementations:
- `WFSA` (Weighted Finite-State Automaton) - for weighted constraints
- `BoolFSA` (Boolean Finite-State Automaton) - for hard constraints
Both FSAs:
- Support regex patterns with standard syntax
- Operate on byte-level sequences by default
- Can be combined with other potentials via products
Context-free grammars
Similar to FSAs, genlm-control provides two CFG implementations:
- `WCFG` (Weighted Context-Free Grammar)
- `BoolCFG` (Boolean Context-Free Grammar)
BoolCFGs support grammar specification via Lark syntax.
Both CFGs:
- Use Earley parsing for efficient recognition
- Can be combined with other potentials
- Operate on byte-level sequences by default
Note: It is recommended to specify grammars via Lark syntax. The `from_string` method is provided for convenience, but it is less flexible and robust.
Custom potentials
You can create custom potentials to implement specialized constraints or preferences that aren't covered by the built-in options.
Creating a custom potential
To define a custom potential:
- Create a subclass of `Potential`
- Implement the `complete` and `prefix` methods
- Optionally override `logw_next` and the batch methods for performance optimization
When implementing custom potentials, the key is understanding the relationship between complete and prefix. Consider the following example of a potential that only allows sequences of a given length:
```python
class LengthPotential(Potential):
    """A potential that only allows sequences of a given length."""

    def __init__(self, vocabulary, length):
        # Initialize the superclass with the potential's vocabulary.
        super().__init__(vocabulary)
        self.length = length

    async def complete(self, context):
        # Note: 0.0 = log(1.0) and float('-inf') = log(0.0)
        return 0.0 if len(context) == self.length else float('-inf')

    async def prefix(self, context):
        # Note: 0.0 = log(1.0) and float('-inf') = log(0.0)
        return 0.0 if len(context) <= self.length else float('-inf')

length_potential = LengthPotential(
    vocabulary=[b'the', b'a', b'cat', b'dog', b'saw', b'chased'], length=5
)
```
This example illustrates the key difference between complete and prefix: the complete method only allows sequences of exactly the target length, while the prefix method allows any sequence that could potentially reach the target length (i.e., any sequence not exceeding the target length).
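The same semantics can be checked in standalone Python (no library dependency; the free functions below stand in for the class methods above):

```python
import math

# Standalone sketch of the LengthPotential semantics: complete accepts only
# sequences of exactly the target length, while prefix accepts any sequence
# that has not yet exceeded it.
TARGET = 5

def complete(context):
    return 0.0 if len(context) == TARGET else -math.inf

def prefix(context):
    return 0.0 if len(context) <= TARGET else -math.inf

four = [b'the', b'cat', b'saw', b'a']
five = four + [b'dog']
six = five + [b'the']
# four is a valid prefix but not complete; five is complete; six is ruled
# out by both, consistent with the zero-propagation requirement.
```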
Common pitfalls
When implementing custom potentials, be aware of these common issues:
- **Inconsistent complete/prefix relationship** - If your `prefix` method assigns zero weight to a sequence, all extensions must also have zero weight.
- **Inefficient implementations** - For complex potentials, consider overriding `logw_next` with a more efficient implementation than the default.
- **Not handling async properly** - All potential methods are asynchronous. Make sure to use `await` when calling them and define your methods with `async def`.
Testing your custom potential
Potentials automatically inherit from the PotentialTests mixin, which provides a number of tests for validating the correctness of the potential's implementation.
```python
# These will raise an exception if the potential implementation
# does not satisfy the required properties
await potential.assert_logw_next_consistency(context)
await potential.assert_autoreg_fact(context)
await potential.assert_batch_consistency(contexts)
```
Complex usage
Products of potentials
The Product class allows you to combine two potentials. A `Product` is itself a potential: it implements all potential methods, and products can be chained to combine more than two potentials.
```python
# Example: Prompt intersection
mtl_llm = PromptedLLM.from_name("gpt2")
mtl_llm.set_prompt_from_str("Montreal is")

bos_llm = mtl_llm.spawn()
bos_llm.set_prompt_from_str("Boston is")

# Create a product using the multiplication operator
product = mtl_llm * bos_llm
```
The product potential operates on the intersection of the two potentials' vocabularies. For a product potential:
- The vocabulary \(\A\) is the intersection of the two potentials' vocabularies: \(\A = \A_1 \cap \A_2\).
- The prefix potential \(\prefix\) is the product (sum in log space) of the individual prefix potentials: \(\log \prefix(\xx) = \log \prefix_1(\xx) + \log \prefix_2(\xx)\).
- The complete potential \(\complete\) is the product (sum in log space) of the individual complete potentials: \(\log \complete(\xx) = \log \complete_1(\xx) + \log \complete_2(\xx)\).
- The next-token potential \(\pot(\cdot \mid \xx)\) is the product (sum in log space) of the individual next-token potentials: \(\log \pot(x \mid \xx) = \log \pot_1(x \mid \xx) + \log \pot_2(x \mid \xx)\) for \(x \in (\A_1 \cap \A_2) \cup \{\eos\}\)
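These rules can be illustrated with a few lines of standalone Python (the dictionaries and weight values below are made up for illustration):

```python
import math

# Illustrative sketch: a product potential sums log-weights over the
# intersection of the two vocabularies. Values are made up.
logw1 = {b'a': math.log(0.5), b'b': math.log(0.25), b'c': 0.0}
logw2 = {b'a': math.log(0.5), b'b': -math.inf}   # b'c' is not in vocab 2

shared = logw1.keys() & logw2.keys()
product_logw = {x: logw1[x] + logw2[x] for x in shared}
# b'a' gets log(0.5) + log(0.5) = log(0.25); b'b' is ruled out by the
# second potential; b'c' is dropped since it is not in both vocabularies.
```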
Warning: Be careful when taking products of potentials with minimal vocabulary overlap, as the resulting potential will only operate on tokens present in both vocabularies. A warning will be raised if the vocabulary overlap is less than 10% of either potential's vocabulary.
Coerced potentials
The Coerced class allows you to adapt a potential to work with a different vocabulary using a coercion function. The coercion function must map sequences of tokens in the new vocabulary to sequences of tokens in the potential's original vocabulary. This is particularly useful when combining potentials that operate on different types of tokens.
```python
# Example: Coercing a byte-level FSA to work with a language model's tokens
fsa = BoolFSA.from_regex(r"\sthe\s(best|worst).*")  # operates on bytes
llm = PromptedLLM.from_name("gpt2")  # operates on byte sequences

# Coerce the FSA to the LLM's tokens by joining the tokens into bytes
coerced_fsa = fsa.coerce(llm, f=b''.join)

# Now we can combine them using the product operator!
product = llm * coerced_fsa
```
Common use cases for coercion include:
- Adapting byte-level constraints (like FSAs) to work with token-level language models (which have vocabularies of byte sequences)
- Implementing constraints that operate on processed versions of the tokens (e.g., lowercase text)
- Converting between different tokenization schemes
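The first two use cases can be sketched as plain coercion functions (standalone Python; `join_bytes` and `lowercase_bytes` are illustrative names, not library APIs):

```python
# Sketches of coercion functions like the f=b''.join used above: each maps
# a sequence of tokens in the new vocabulary (an LLM's bytes tokens) to a
# sequence the original potential understands.
def join_bytes(token_sequence):
    # Byte-level constraints score the flat byte string.
    return b''.join(token_sequence)

def lowercase_bytes(token_sequence):
    # "Processed view" coercion: score the lowercased text instead.
    return b''.join(token_sequence).lower()

flat = join_bytes([b' the', b' best', b' movie'])
low = lowercase_bytes([b'The', b' BEST'])
```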
Performance Note: The coercion operation can impact performance, especially when mapping from a coarser token type to a finer token type (e.g., byte sequences to individual bytes). To sample tokens from a coerced product, consider using specialized samplers (e.g., `eager_token_sampler`, `topk_token_sampler`).
Performance optimizations
genlm-control provides a number of performance optimizations for potentials, described in the performance section.
Formalization
This section provides a formal definition of potentials and the relationships between their complete, prefix, and next-token potentials.
Notation. Let \(\A\) be a vocabulary of tokens and \(\eos\) a distinguished end-of-sequence token. Let \(\A^*\) denote the set of all sequences of tokens that can be built from \(\A\) (including the empty sequence \(\epsilon\)), and let \(\A^*{\eos} = \{\xx\eos : \xx \in \A^*\}\) denote the set of \(\eos\)-terminated sequences. We refer to \(\A^*\) as the set of prefix sequences and to \(\A^*{\eos}\) as the set of complete sequences.
A potential \(\pot\) is a function \(\pot: \A^* \cup \A^*{\eos} \rightarrow \mathbb{R}_{\geq 0}\) which assigns a non-negative real number to prefix and complete sequences from its vocabulary \(\A\):

\[\pot(\xx) = \begin{cases} \complete(\xx') & \text{if } \xx = \xx'\eos \text{ for some } \xx' \in \A^* \\ \prefix(\xx) & \text{if } \xx \in \A^* \end{cases}\]

where
- \(\prefix : \A^* \rightarrow \mathbb{R}_{\geq 0}\) is the prefix potential
- \(\complete : \A^* \rightarrow \mathbb{R}_{\geq 0}\) is the complete potential
The complete and prefix potentials are related by the following consistency condition: for all \(\xx \in \A^*\),

\[\prefix(\xx) = 0 \implies \prefix(\xx\yy) = 0 \text{ and } \complete(\xx\yy\eos) = 0 \quad \text{for all } \yy \in \A^*\]

Intuitively, this means that the prefix potential cannot rule out a sequence which may later turn out to be valid according to the complete potential.
Finally, we define the next-token weights function \(\pot(\cdot \mid \xx) : \A \cup \{\eos\} \rightarrow \mathbb{R}_{\geq 0}\), which assigns a non-negative real number to each token \(x \in \A \cup \{\eos\}\) given a sequence \(\xx \in \A^*\) with \(\prefix(\xx) > 0\):

\[\pot(x \mid \xx) = \begin{cases} \dfrac{\prefix(\xx x)}{\prefix(\xx)} & \text{if } x \in \A \\ \dfrac{\complete(\xx)}{\prefix(\xx)} & \text{if } x = \eos \end{cases}\]

\(\pot(\cdot \mid \xx)\) is related to the complete and prefix potentials according to the following autoregressive factorization:

\[\complete(\xx) = \prefix(\epsilon) \cdot \pot(\eos \mid \xx) \cdot \prod_{t=1}^{|\xx|} \pot(x_t \mid \xx_{<t})\]
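The factorization can be verified numerically on a toy potential (standalone Python; the functions and weight values below are illustrative assumptions, not library code):

```python
import math

# Numeric check of the autoregressive factorization on a toy potential,
# all in log space. Sequences of length <= 2 are valid prefixes; only
# length-2 sequences are complete.
def prefix(ctx):
    return len(ctx) * math.log(0.5) if len(ctx) <= 2 else -math.inf

def complete(ctx):
    return math.log(0.1) if len(ctx) == 2 else -math.inf

def logw_next(x, ctx):
    # Next-token log-weight: prefix ratio, or complete/prefix for eos.
    if x == "<eos>":
        return complete(ctx) - prefix(ctx)
    return prefix(ctx + [x]) - prefix(ctx)

xs = ["a", "b"]
# log complete(xs) equals log prefix of the empty sequence plus the sum of
# next-token log-weights, ending with eos (the product telescopes).
lhs = complete(xs)
rhs = (prefix([])
       + sum(logw_next(xs[t], xs[:t]) for t in range(len(xs)))
       + logw_next("<eos>", xs))
```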
Correspondence with the Potential class
Each of the quantities above directly corresponds to a method or attribute of the Potential class:
| Method/Attribute | Mathematical Quantity | Description |
|---|---|---|
| `vocab` | \(\A\) | The vocabulary of the potential. |
| `eos` | \(\eos\) | The end-of-sequence token. |
| `vocab_eos` | \(\A \cup \{\eos\}\) | The vocabulary of the potential including the end-of-sequence token. |
| `complete(self, context)` | \(\log \complete(\xx)\) | The complete potential for a given sequence. |
| `prefix(self, context)` | \(\log \prefix(\xx)\) | The prefix potential for a given sequence. |
| `logw_next(self, context)` | \(\log \pot(\cdot \mid \xx)\) | The next-token potential for a given prefix sequence. |
| `score(self, context)` | \(\log \pot(\xx)\) | The potential, dispatching to `complete` for \(\eos\)-terminated sequences and `prefix` otherwise. |