Pattern Matching
This example shows how to evaluate a genlm.control model on the pattern matching domain.
- Task: Generate strings that conform to expressive pattern-matching specifications. Unlike formal regular expressions, these patterns include features that cannot be fully captured by deterministic finite-state automata, such as unbounded center embedding and conditionals.
- Data: Over 400 pattern-matching specifications generated via the pipeline described in Appendix I of Lipkin et al. (2025).
Setup
First, install the dependencies for this domain. In the root directory, run:
pip install -e .[pattern_matching]
Second, download the patterns.csv file from the assets/pattern_matching directory in the repository. (Note that you can also use your own patterns.)
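If you bring your own patterns, the loader used in this example expects a CSV with a column of pattern strings (the examples below use a column named `regex`). A minimal sketch of writing such a file with the standard library (the file name and patterns here are illustrative):

```python
import csv

# Hypothetical custom patterns; any regex-style pattern strings work here.
patterns = [
    r"^[0-9]{3}-[0-9]{4}$",  # phone-number-like strings
    r"^(a|b)\1$",            # a pattern with a backreference
]

with open("my_patterns.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["regex"])  # header matching pattern_column="regex"
    writer.writerows([[p] for p in patterns])
```

The file can then be passed to the dataset loader in place of the bundled `patterns.csv`.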
Usage
Initialize the dataset and evaluator
In [1]:
from genlm.eval.domains.pattern_matching import (
    PatternMatchingDataset,
    PatternMatchingEvaluator,
)
In [2]:
dataset = PatternMatchingDataset.from_csv(
    "../../../assets/pattern_matching/patterns.csv", pattern_column="regex"
)
evaluator = PatternMatchingEvaluator()
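Conceptually, the evaluator scores each weighted response by whether it matches the instance's pattern, so the weighted accuracy for one instance is the total normalized weight on matching responses. A minimal sketch of that idea (illustrative, not the genlm implementation; the pattern and responses below are made up):

```python
import re


def weighted_accuracy(pattern, weighted_responses):
    """Total normalized weight on responses that fully match the pattern."""
    total = sum(w for _, w in weighted_responses)
    matched = sum(w for r, w in weighted_responses if re.fullmatch(pattern, r))
    return matched / total


# A MAC-address-like pattern: 0.7 of the weight is on a matching response.
acc = weighted_accuracy(
    r"[0-9A-Fa-f]{2}(:[0-9A-Fa-f]{2}){5}",
    [("3D:F2:C9:A6:B3:4F", 0.7), ("not-a-mac", 0.3)],
)
```

Note that Python's `re` module supports only a subset of the features used by the benchmark patterns; the sketch conveys the scoring logic, not the full matcher.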
Define a model adaptor
A model adaptor is an async callable that takes a PatternMatchingInstance (along with an output directory and a replicate index) and returns a ModelOutput. For this example, we'll use a constrained genlm.control.PromptedLLM to generate responses.
In [3]:
from genlm.control import PromptedLLM, AWRS
from genlm.eval import ModelOutput, ModelResponse
from genlm.eval.domains.pattern_matching import (
    default_prompt_formatter,
    PatternPotential,
)

# Load an LLM.
LLM = PromptedLLM.from_name("gpt2", eos_tokens=[b"\n", b"\n\n"])


async def model(instance, output_dir, replicate):
    # Set the prompt for the LLM.
    LLM.prompt_ids = default_prompt_formatter(
        LLM.model.tokenizer, instance, use_chat_format=False
    )

    # Define a potential that ensures the generated text matches the pattern.
    potential = PatternPotential(instance.pattern).coerce(LLM, f=b"".join)

    # Define an adaptive weighted rejection sampler to sample tokens from the
    # constrained model.
    sampler = AWRS(LLM, potential)

    # Run SMC to sample sequences from the constrained model.
    sequences = await sampler.smc(
        n_particles=5,
        ess_threshold=0.5,
        max_tokens=100,
    )

    # Return the sampled sequences and their probabilities as a ModelOutput.
    return ModelOutput(
        responses=[
            ModelResponse(response=sequence, weight=prob)
            for sequence, prob in sequences.decoded_posterior.items()
        ],
    )
/opt/homebrew/Caskroom/miniconda/base/envs/genlm/lib/python3.11/site-packages/genlm/backend/tokenization/vocab.py:98: UserWarning: Duplicate tokens found in string vocabulary. This may lead to downstream issues with the string vocabulary; we recommend using the byte vocabulary. warnings.warn(
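The decoded_posterior used above maps each distinct decoded sequence to its normalized posterior probability. A minimal sketch of that aggregation, assuming particles carry (sequence, unnormalized weight) pairs (the names here are illustrative, not the genlm API):

```python
from collections import defaultdict


def decoded_posterior(particles):
    """Aggregate unnormalized particle weights into a posterior over sequences."""
    totals = defaultdict(float)
    for sequence, weight in particles:
        totals[sequence] += weight  # duplicate sequences pool their weight
    z = sum(totals.values())  # normalizing constant
    return {seq: w / z for seq, w in totals.items()}


# Two particles decoded to "abc", one to "abd".
posterior = decoded_posterior([("abc", 2.0), ("abd", 1.0), ("abc", 1.0)])
```

The resulting probabilities sum to one, which is why they can be passed directly as ModelResponse weights.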
Run the evaluation
In [4]:
from genlm.eval import run_evaluation

results = await run_evaluation(
    dataset=dataset,
    model=model,
    evaluator=evaluator,
    max_instances=5,
    n_replicates=1,
    verbosity=1,
    # output_dir="pattern_matching_results",  # optionally save the results to a directory
)
Instance instance_id=0 pattern='(?<!\\d{3})abc(?!\\d{3})'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
Instance instance_id=1 pattern='^(?|(a)|(b)|(c))\\1$'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
Instance instance_id=2 pattern='[\\p{IsAlphabetic}&&[\\P{L}]]'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
Instance instance_id=3 pattern='^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
Instance instance_id=4 pattern='^[a-f0-9]{8}-[a-f0-9]{4}-[1-5][a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$'
Mean weighted accuracy (instance): 0.9999999999999999
Mean weighted accuracy (total): 1.0
In [5]:
results.keys()
Out[5]:
dict_keys(['average_weighted_accuracy', 'n_instances', 'all_instance_results', 'all_instance_outputs'])
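Given the keys above, the aggregate metric can be read off directly. A sketch assuming the structure shown, using a stand-in dict rather than a real results object:

```python
# Stand-in for the dict returned by run_evaluation, mirroring the keys above.
results = {
    "average_weighted_accuracy": 1.0,
    "n_instances": 5,
    "all_instance_results": [],
    "all_instance_outputs": [],
}

summary = (
    f"{results['n_instances']} instances, "
    f"mean weighted accuracy {results['average_weighted_accuracy']:.3f}"
)
```

The per-instance entries under `all_instance_results` and `all_instance_outputs` can be inspected the same way for finer-grained analysis.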