Evaluate Models Using Custom Constraints
This notebook shows how to evaluate models using custom constraints implemented with genlm-control. It builds upon the custom domain tutorial in custom_domains.ipynb, which covers creating dataset classes and evaluators.
This tutorial covers:
- Implementing custom potentials to encode constraints
- Implementing a model adaptor that uses constrained generation
- Running the evaluation
The following example demonstrates these steps on the pattern-matching domain, generating strings that conform to regex pattern specifications.
1. Implement custom potentials
A potential encodes constraints or preferences by assigning non-negative weights to sequences of tokens. Potentials can be used as components of samplers, where they help propose new tokens at each generation step, or as critics, where they reweight sequences at each step according to whether they satisfy the encoded constraint.
Each potential has a vocabulary that specifies the set of tokens it operates on. Potentials must implement the prefix and complete methods, which assign weights to partial and complete sequences, respectively. For a complete guide on implementing potentials, see the GenLM Control documentation.
Here we use a PatternPotential, which checks whether a sequence fully matches the pattern (complete) or can still be extended to a match (prefix). Both methods return log weights: 0.0 when the check passes and -inf otherwise.
import string

import regex
from genlm.control import Potential


class PatternPotential(Potential):
    """Potential function for regex pattern matching."""

    def __init__(self, pattern):
        vocab = list(map(ord, string.printable))
        super().__init__(vocab)
        self.r = regex.compile(pattern)

    async def complete(self, context):
        text = "".join(map(chr, context))
        match = self.r.fullmatch(text) is not None
        return 0.0 if match else float("-inf")

    async def prefix(self, context):
        text = "".join(map(chr, context))
        m = self.r.match(text, partial=True)
        match = m is not None and m.start() == 0 and m.end() == len(text)
        return 0.0 if match else float("-inf")
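The difference between prefix and complete is easiest to see on a tiny example. The following is a minimal sketch that does not use genlm-control or the regex package. It replaces the pattern with a hypothetical finite set of allowed strings, ALLOWED, chosen to mirror the language of the regex xy|xz. The helper names (decode, encode, demo) are illustrative assumptions, not part of any library.

```python
import asyncio
import math

# Hypothetical stand-in for the language of the regex r"xy|xz".
ALLOWED = ("xy", "xz")


def decode(context):
    """Turn a list of integer code points into a string, as PatternPotential does."""
    return "".join(map(chr, context))


def encode(text):
    """Inverse of decode: a string becomes a list of integer code points."""
    return [ord(c) for c in text]


async def complete(context):
    """Log weight of `context` as a finished sequence: 0.0 if allowed, else -inf."""
    return 0.0 if decode(context) in ALLOWED else -math.inf


async def prefix(context):
    """Log weight of `context` as a partial sequence: 0.0 if it can still reach a match."""
    text = decode(context)
    return 0.0 if any(s.startswith(text) for s in ALLOWED) else -math.inf


async def demo():
    # Map each test string to its (prefix, complete) log weights.
    return {
        text: (await prefix(encode(text)), await complete(encode(text)))
        for text in ["", "x", "xy", "xa"]
    }


results = asyncio.run(demo())
for text, (p, c) in results.items():
    print(repr(text), "prefix:", p, "complete:", c)
```

In this sketch, "x" is a valid prefix but not a valid complete sequence, "xy" is both, and "xa" is neither. Inside a notebook, where an event loop is already running, `await demo()` would replace the `asyncio.run(demo())` call.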
2. Implement a model adaptor
A model adaptor is an async callable that takes a PatternMatchingInstance and returns a ModelOutput. For this example, we'll use a constrained genlm.control.PromptedLLM to generate responses.
from genlm.control import PromptedLLM, AWRS
from genlm.eval import ModelOutput, ModelResponse
from genlm.eval.domains.pattern_matching import (
    default_prompt_formatter,
)

# Load an LLM
LLM = PromptedLLM.from_name("gpt2", eos_tokens=[b"\n", b"\n\n"])


async def model(instance, output_dir, replicate):
    # Set the prompt for the LLM.
    LLM.prompt_ids = default_prompt_formatter(
        LLM.model.tokenizer, instance, use_chat_format=False
    )

    # Define a potential that ensures the generated text matches the pattern
    potential = PatternPotential(instance.pattern).coerce(LLM, f=b"".join)

    # Define an adaptive weighted rejection sampler to sample tokens from the constrained model.
    sampler = AWRS(LLM, potential)

    # Run SMC to sample sequences from the constrained model.
    sequences = await sampler.smc(
        n_particles=5,
        ess_threshold=0.5,
        max_tokens=100,
    )

    # Return the sampled sequences and their probabilities as a ModelOutput.
    return ModelOutput(
        responses=[
            ModelResponse(response=sequence, weight=prob)
            for sequence, prob in sequences.decoded_posterior.items()
        ],
    )
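The ess_threshold argument refers to the effective sample size (ESS) of the particle weights, a standard SMC diagnostic. The sketch below computes the usual ESS estimate, (sum of weights)^2 / (sum of squared weights), from log weights. The resampling rule in should_resample (resample when ESS falls below ess_threshold times the number of particles) is an assumption about how the threshold is interpreted, not a description of the library's internals. Both function names are hypothetical.

```python
import math


def effective_sample_size(log_weights):
    """Return (sum w)^2 / sum(w^2), computed stably from log weights."""
    finite = [lw for lw in log_weights if lw != -math.inf]
    if not finite:
        return 0.0  # every particle has zero weight
    shift = max(finite)  # subtract the max before exponentiating to avoid overflow
    weights = [math.exp(lw - shift) for lw in finite]
    return sum(weights) ** 2 / sum(w * w for w in weights)


def should_resample(log_weights, ess_threshold):
    """Assumed rule: resample when ESS < ess_threshold * n_particles."""
    return effective_sample_size(log_weights) < ess_threshold * len(log_weights)


# Five equally weighted particles: ESS equals the number of particles.
uniform = [0.0] * 5
# One surviving particle, four rejected by the constraint (log weight -inf).
degenerate = [0.0] + [-math.inf] * 4

print(effective_sample_size(uniform), should_resample(uniform, 0.5))
print(effective_sample_size(degenerate), should_resample(degenerate, 0.5))
```

With n_particles=5 and ess_threshold=0.5, this rule would trigger resampling once the ESS drops below 2.5, which happens in the degenerate case but not in the uniform one.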
3. Run the evaluation
Using the dataset, evaluator, potential, and model adaptor, we can now run the evaluation:
from genlm.eval import run_evaluation
from genlm.eval.domains.pattern_matching import (
    PatternMatchingDataset,
    PatternMatchingEvaluator,
)

dataset = PatternMatchingDataset([r"xy|xz", r"ab|c(e|f)"])
evaluator = PatternMatchingEvaluator()

results = await run_evaluation(
    dataset=dataset,
    evaluator=evaluator,
    model=model,
    n_replicates=1,
    verbosity=1,
    # output_dir="results",  # uncomment to save results
)
Instance instance_id=0 pattern='xy|xz'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
Instance instance_id=1 pattern='ab|c(e|f)'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
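One plausible reading of "weighted accuracy" is the total posterior weight assigned to responses that satisfy the pattern. The sketch below illustrates that reading with the standard-library re module. It is an assumption about the metric, not the PatternMatchingEvaluator implementation, and the function name and the example responses are made up for illustration.

```python
import re


def weighted_accuracy(pattern, responses):
    """Sum the weights of responses that fully match `pattern`.

    `responses` is a list of (text, weight) pairs whose weights sum to 1.
    """
    compiled = re.compile(pattern)
    return sum(weight for text, weight in responses if compiled.fullmatch(text))


# Made-up responses for the pattern of instance 1; "cx" does not match.
example = [("ab", 0.5), ("ce", 0.25), ("cx", 0.25)]
print(weighted_accuracy(r"ab|c(e|f)", example))
```

Under this reading, a score of 1.0 as in the output above would mean that every response with non-zero weight matched its pattern, which is what one would expect when the constraint is enforced during generation.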