Goal Inference¶
This example shows how to evaluate a genlm.control model on the Planetarium Blocksworld goal-inference task.
Task: Given a natural-language description, generate the PDDL goal S-expression that completes (:goal (and …)) for a provided problem.
Data: Planetarium (Zuo et al., 2024) Blocksworld subset, filtered to goals of the form (:goal (and …)) and small instances (e.g., < 10 objects).
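For example, given a description like "stack block b1 on block b2", the target completion might look like the following. This is a hypothetical illustration of the goal shape, not an actual dataset instance; the predicate names (`on`, `ontable`, `clear`) are the standard Blocksworld ones.

```lisp
(:goal (and (on b1 b2) (ontable b2) (clear b1)))
```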
Setup¶
First, install the dependencies for this domain. In the root directory, run:
pip install -e .[goal_inference]
Install Fast Downward and VAL¶
This benchmark requires the Fast Downward planner and the VAL plan validator.
To install them on Linux, follow the instructions on https://github.com/BatsResearch/planetarium:
apptainer pull fast-downward.sif docker://aibasel/downward:latest
mkdir tmp
curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip
unzip tmp/VAL.zip -d tmp/
tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1
For other platforms, follow the instructions at https://github.com/aibasel/downward/blob/main/BUILD.md
Make sure to add fast-downward.sif and VAL to your PATH, or create aliases for them.
import os
from pathlib import Path
os.environ["PATH"] = str(Path("../../../tmp/bin")) + os.pathsep + os.environ["PATH"]
Verify the commands are working¶
!Validate -h | head -n 5
!../../../fast-downward.sif -h | head -n 5
VAL: The PDDL+ plan validation tool
Version 4: Validates continuous effects, events and processes.
Authors: Derek Long, Richard Howey, Stephen Cresswell and Maria Fox
https:://github/KCL-Planning/VAL
usage: fast-downward.py [-h] [-v] [--show-aliases] [--run-all] [--translate]
[--search]
[--translate-time-limit TRANSLATE_TIME_LIMIT]
[--translate-memory-limit TRANSLATE_MEMORY_LIMIT]
[--search-time-limit SEARCH_TIME_LIMIT]
Usage¶
Initialize the dataset and evaluator¶
from genlm.eval.domains.goal_inference import (
GoalInferenceDataset,
GoalInferenceEvaluator,
)
dataset = GoalInferenceDataset.from_hf_planetarium(
n_examples=5, max_objects=2, domains=["blocksworld"]
)
evaluator = GoalInferenceEvaluator()
Define a model adaptor¶
A model adaptor is an async callable that takes a GoalInferenceInstance and returns a ModelOutput. Here we'll use a genlm.control.PromptedLLM constrained to valid PDDL goals (via a GoalInferenceVALPotential).
from genlm.control import PromptedLLM, AWRS
from genlm.eval import ModelOutput, ModelResponse
from genlm.eval.domains.goal_inference import (
goal_default_prompt_formatter,
GoalInferenceVALPotential,
)
# Read the domain PDDL file.
with open("../../../assets/goal_inference/pddl_domains/blocksworld.pddl") as f:
domain_text = f.read()
# Load an LLM
LLM = PromptedLLM.from_name("meta-llama/Meta-Llama-3-8B")
async def model(instance, output_dir, replicate):
# Set the prompt for the LLM.
LLM.prompt_ids = goal_default_prompt_formatter(
LLM.model.tokenizer, instance, use_chat_format=False
)
# Construct goal validation potential.
potential = GoalInferenceVALPotential(
domain_pddl_text=domain_text,
problem_pddl_text=instance.problem_text,
fast_downward_cmd="../../../fast-downward.sif",
val_cmd="Validate",
cache_root=".cache",
verbosity=0,
).coerce(LLM, f=b"".join)
# Define an adaptive weighted rejection sampler to sample tokens from the constrained model.
sampler = AWRS(LLM, potential)
# Run SMC to sample sequences from the constrained model.
sequences = await sampler.smc(
n_particles=10,
ess_threshold=0.9,
max_tokens=100,
)
# Return the sampled sequences and their probabilities as a ModelOutput.
return ModelOutput(
responses=[
ModelResponse(response=sequence, weight=prob)
for sequence, prob in sequences.decoded_posterior.items()
],
)
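The `decoded_posterior` used above maps each decoded sequence to its normalized posterior probability, so each `ModelResponse` carries a weight. The evaluator can then score a whole weighted set of candidates rather than a single sample. As a sketch of the idea (a simplified stand-in, not the library's actual implementation), posterior-weighted accuracy is the total weight assigned to correct responses:

```python
# Sketch: posterior-weighted accuracy over (response, weight) pairs,
# as produced by the ModelOutput above. `is_correct` stands in for the
# evaluator's semantic check of a generated goal against the reference.

def weighted_accuracy(responses, is_correct):
    """Fraction of posterior mass assigned to correct responses."""
    total = sum(weight for _, weight in responses)
    if total == 0:
        return 0.0
    return sum(weight for resp, weight in responses if is_correct(resp)) / total

# Toy example: two candidate goals with posterior weights 0.9 and 0.1.
candidates = [("(and (on b1 b2))", 0.9), ("(and (on b2 b1))", 0.1)]
print(weighted_accuracy(candidates, lambda r: r == "(and (on b1 b2))"))
```

If the correct goal carries 90% of the posterior mass, the instance scores 0.9 even though an incorrect candidate was also sampled.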
Run the evaluation¶
from genlm.eval import run_evaluation
results = await run_evaluation(
dataset=dataset,
model=model,
evaluator=evaluator,
max_instances=3,
n_replicates=1,
verbosity=3,
#output_dir="goal_inference_results", #optionally save the results to a directory
)
Instance GoalInferenceInstance(id=9504, domain=blocksworld)
Mean weighted accuracy (instance): 0.8774042329086134
Mean weighted accuracy (total): 0.8774042329086134
Instance GoalInferenceInstance(id=8740, domain=blocksworld)
Mean weighted accuracy (instance): 0.9999999999999999
Mean weighted accuracy (total): 0.9387021164543066
Instance GoalInferenceInstance(id=8744, domain=blocksworld)
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 0.959134744302871
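In the log above, the "(total)" value after each instance is the running mean of the per-instance scores seen so far. A quick check of the arithmetic:

```python
# Per-instance weighted accuracies from the run above.
scores = [0.8774042329086134, 0.9999999999999999, 1.0]

# Running mean after each instance, matching the "(total)" column.
running = [sum(scores[: i + 1]) / (i + 1) for i in range(len(scores))]
print(running)
```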
References¶
Max Zuo, Francisco Piedrahita Velez, Xiaochen Li, Michael L. Littman, and Stephen H. Bach. Planetarium: a rigorous benchmark for translating text to structured planning languages. arXiv preprint arXiv:2407.03321, 2024. URL https://arxiv.org/abs/2407.03321