# Custom Domains
This library is designed to be extensible to new domains. To evaluate a model on a custom domain, you need to:
- Define your dataset
- Implement an evaluator
- Implement a prompt formatter
- Implement a model adaptor
- Run the evaluation
The following example demonstrates these steps on the pattern matching domain.
## 1. Define your dataset
A dataset is an iterator over instances that satisfy a schema. The schema is defined by a class that inherits from `Instance`.
```python
from genlm.eval import Instance


class PatternMatchingInstance(Instance):
    """Schema for a pattern matching instance."""

    pattern: str
    instance_id: int

    def __repr__(self):
        return f"pattern: {self.pattern} (id: {self.instance_id})"
```
Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method that yields instances of the schema.
```python
from genlm.eval import Dataset


class PatternMatchingDataset(Dataset[PatternMatchingInstance]):
    """Dataset for pattern matching evaluation."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __iter__(self):
        """Iterate over regex patterns.

        Returns:
            (Iterator[PatternMatchingInstance]): Iterator over regex instances.
        """
        for pattern_id, pattern in enumerate(self.patterns):
            yield PatternMatchingInstance(pattern=pattern, instance_id=pattern_id)

    @property
    def schema(self):
        """Get the schema class for this dataset."""
        return PatternMatchingInstance
```
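Iterating over a small dataset then yields one instance per pattern; a minimal illustrative check:

```python
# Illustrative: each pattern becomes one PatternMatchingInstance.
dataset = PatternMatchingDataset([r"(ab)+", r"[a-z]+"])
for instance in dataset:
    print(instance)
# pattern: (ab)+ (id: 0)
# pattern: [a-z]+ (id: 1)
```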
## 2. Implement an evaluator
An evaluator is the class responsible for scoring model outputs. Subclasses must, at a minimum, implement the `evaluate_sample` method, which takes an instance and a response and returns an `EvaluationResult`.
```python
import regex

from genlm.eval import Evaluator, EvaluationResult


class PatternMatchingEvaluator(Evaluator[PatternMatchingInstance]):
    """Evaluator for pattern matching."""

    def evaluate_sample(self, instance, response):
        """Evaluate if a response matches the regex pattern."""
        is_valid = regex.compile(instance.pattern).fullmatch(response) is not None
        return EvaluationResult(
            score=int(is_valid), desc="valid" if is_valid else "invalid"
        )
```
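The evaluator can be sanity-checked directly on a hand-written instance. The sketch below is illustrative and assumes `EvaluationResult` exposes its `score` field:

```python
# Illustrative: a matching response scores 1, a non-matching response scores 0.
evaluator = PatternMatchingEvaluator()
instance = PatternMatchingInstance(pattern=r"(ab)+", instance_id=0)
print(evaluator.evaluate_sample(instance, "abab").score)  # 1 ("valid")
print(evaluator.evaluate_sample(instance, "abc").score)   # 0 ("invalid")
```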
## 3. Implement a prompt formatter
A prompt formatter standardizes and tokenizes the input to the model, optionally prepending a system prompt and few-shot examples for the evaluation.
```python
from genlm.eval.util import chat_template_messages

FEW_SHOT_EXAMPLES = [
    ("(ab)+", "ab"),
    ("(ab|cd)+", "cd"),
    ("[a-z]+", "hello"),
]

SYSTEM_PROMPT = (
    "You are a helpful assistant that generates strings matching regular expressions. "
    + "Only output the exact string that matches the regex pattern, nothing more."
)


def default_prompt_formatter(
    tokenizer,
    instance,
    use_chat_format=False,
    system_prompt=SYSTEM_PROMPT,
    few_shot_examples=FEW_SHOT_EXAMPLES,
):
    """Default prompt formatter for pattern matching.

    Args:
        tokenizer (Tokenizer): The tokenizer to use.
        instance (PatternMatchingInstance): The instance to format.
        use_chat_format (bool): Whether to use chat format.
        system_prompt (str): The system prompt to use.
        few_shot_examples (list[tuple[str, str]]): The few-shot examples to use.
            Each example is a tuple of (pattern, response).

    Returns:
        (list[int]): The prompt ids.
    """
    if use_chat_format:
        return tokenizer.apply_chat_template(
            chat_template_messages(
                system_prompt,
                few_shot_examples,
                instance.pattern,
            ),
            tokenize=True,
            add_generation_prompt=True,
        )
    else:
        return tokenizer.encode(
            system_prompt
            + "\n"
            + "\n".join(
                f"Pattern: {pattern}\nOutput: {output}"
                for pattern, output in few_shot_examples
            )
            + "\n"
            + instance.pattern
        )
```
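For example, the formatter can be exercised with a Hugging Face tokenizer (this assumes the `transformers` package; any tokenizer exposing `encode` and `apply_chat_template` works):

```python
# Illustrative: build a non-chat prompt with a Hugging Face tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
instance = PatternMatchingInstance(pattern=r"xy|xz", instance_id=0)
prompt_ids = default_prompt_formatter(tokenizer, instance, use_chat_format=False)
print(tokenizer.decode(prompt_ids))  # system prompt, few-shot examples, then the pattern
```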
## 4. Implement a model adaptor
A model adaptor is an async callable that takes a dataset instance (here, a `PatternMatchingInstance`), an output directory, and a replicate index, and returns a `ModelOutput`.
For this example, we'll use a `PromptedLLM` with a direct token sampler, which proposes tokens by sampling from the LM's distribution.
See `custom_potentials.ipynb` for a tutorial on how to implement custom constraints in genlm-control and evaluate the model.
```python
from genlm.control import PromptedLLM, direct_token_sampler
from genlm.eval import ModelOutput, ModelResponse

# Load an LLM
LLM = PromptedLLM.from_name("gpt2", eos_tokens=[b"\n", b"\n\n"])


async def model(instance, output_dir, replicate):
    # Set the prompt for the LLM.
    LLM.prompt_ids = default_prompt_formatter(
        LLM.model.tokenizer, instance, use_chat_format=False
    )

    # Load a sampler that proposes tokens by sampling directly from the LM's distribution.
    sampler = direct_token_sampler(LLM)

    # Run SMC with 5 particles and a maximum of 100 tokens per sequence.
    sequences = await sampler.smc(
        n_particles=5,
        max_tokens=100,
        ess_threshold=0.0,
    )

    # Return the sampled sequences and their posterior probabilities as a ModelOutput.
    return ModelOutput(
        responses=[
            ModelResponse(response=sequence, weight=prob)
            for sequence, prob in sequences.decoded_posterior.items()
        ],
    )
```
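Before running the full evaluation, the adaptor can be tried on a single instance. This is an illustrative sketch: it assumes `ModelOutput` exposes its `responses` field and passes placeholder values for the unused `output_dir` and `replicate` arguments:

```python
# Illustrative: run the adaptor on one instance from an async context (e.g. a notebook).
instance = PatternMatchingInstance(pattern=r"xy|xz", instance_id=0)
output = await model(instance, output_dir=None, replicate=0)
for response in output.responses:
    print(response.response, response.weight)
```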
## 5. Run the evaluation
Using the dataset, evaluator, and model adaptor, we can now run the evaluation:
```python
from genlm.eval import run_evaluation

dataset = PatternMatchingDataset([r"xy|xz", r"ab|c(e|f)"])
evaluator = PatternMatchingEvaluator()

results = await run_evaluation(
    dataset=dataset,
    evaluator=evaluator,
    model=model,
    n_replicates=1,
    verbosity=1,
    # output_dir="results",  # uncomment to save results
)
```
```
Instance instance_id=0 pattern='xy|xz'
Mean weighted accuracy (instance): 0.0
Mean weighted accuracy (total): 0.0

Instance instance_id=1 pattern='ab|c(e|f)'
Mean weighted accuracy (instance): 0.0
Mean weighted accuracy (total): 0.0
```