Skip to content

goal_inference

genlm.eval.domains.goal_inference

GoalInferenceInstance

Bases: Instance

Schema for a single Planetarium goal-inference item.

Source code in genlm/eval/domains/goal_inference/goal_inference.py
class GoalInferenceInstance(Instance):
    """Schema for a single Planetarium goal-inference item.

    ``instance_id`` is not declared here; it is read in ``__str__`` and is
    presumably inherited from ``Instance`` (TODO confirm).
    """

    # Natural-language description of the goal shown to the model.
    nl_goal: str
    # Full reference PDDL problem text.
    problem_text: str
    # Problem text with the goal section replaced by a "[BLANK]" marker.
    masked_pddl: str
    # Problem text up to and including "(:goal (and", used as the prompt prefix.
    prefix_pddl: str
    # Planning domain of the problem (e.g. "blocksworld").
    domain_name: str

    def __str__(self):
        """Return a compact identifier showing the instance id and domain."""
        ident = self.instance_id
        domain = self.domain_name
        return f"GoalInferenceInstance(id={ident}, domain={domain})"

GoalInferenceDataset

Bases: Dataset[GoalInferenceInstance]

Dataset wrapper yielding GoalInferenceInstance items.

Source code in genlm/eval/domains/goal_inference/goal_inference.py
class GoalInferenceDataset(Dataset[GoalInferenceInstance]):
    """Dataset wrapper yielding GoalInferenceInstance items.

    Each stored record is a dict with keys ``problem_text``, ``nl_goal`` and
    ``domain_name`` (and optionally ``instance_id``). The prompt prefix and the
    masked reference PDDL are derived from ``problem_text`` during iteration.
    """

    def __init__(self, dev_items: List[dict]):
        """Store preprocessed records.

        Args:
            dev_items: Record dicts, e.g. as produced by ``from_hf_planetarium``.
        """
        self.dev_items = dev_items

    @staticmethod
    def _make_prefix_pddl(problem_text: str) -> Optional[str]:
        """Return text up to and including '(:goal (and' for prompting.

        Whitespace between ``(:goal`` and ``(and`` may vary.

        Args:
            problem_text: Full PDDL problem text.

        Returns:
            Prefix string or None if the pattern is absent.
        """
        m = re.search(r"\(:goal\s*\(and", problem_text)
        if not m:
            return None
        return problem_text[: m.end()]

    @staticmethod
    def _mask_goal_for_reference(problem_text: str) -> Optional[str]:
        """Create a masked PDDL with '[BLANK]' in place of the goal.

        Everything from the first ``(:goal`` onwards is discarded and replaced
        by ``(:goal (and [BLANK]))`` followed by a newline and the closing
        paren of ``define``. Any sections after the goal are not kept.

        Args:
            problem_text: Full PDDL problem text.

        Returns:
            Masked PDDL or None if no goal section is found.
        """
        i = problem_text.find("(:goal")
        if i == -1:
            return None
        prefix_before_goal = problem_text[:i]
        goal_suffix = "(:goal (and [BLANK]))\n)"
        return prefix_before_goal + goal_suffix

    def __iter__(self):
        """Yield GoalInferenceInstance objects built from stored records.

        Records whose problem text lacks a goal section (so no prefix or mask
        can be built) are skipped. When a record has no ``instance_id``, its
        position in ``dev_items`` is used instead.
        """
        for i, rec in enumerate(self.dev_items):
            problem_text = rec["problem_text"]
            prefix_pddl = self._make_prefix_pddl(problem_text)
            masked_pddl = self._mask_goal_for_reference(problem_text)
            if prefix_pddl is None or masked_pddl is None:
                continue

            yield GoalInferenceInstance(
                nl_goal=rec["nl_goal"],
                problem_text=problem_text,
                masked_pddl=masked_pddl,
                prefix_pddl=prefix_pddl,
                instance_id=rec.get("instance_id", i),
                domain_name=rec["domain_name"],
            )

    @classmethod
    def from_hf_planetarium(
        cls,
        n_examples: int = 100,
        max_objects: int = 9,
        shard_filename: str = "data/train-00000-of-00001.parquet",
        domains: Optional[List[str]] = None,
    ) -> "GoalInferenceDataset":
        """Load and filter Planetarium data via HuggingFace.

        Args:
            n_examples: Number of instances to evaluate.
            max_objects: Keep problems with at most this many objects.
            shard_filename: Specific shard file to download from Planetarium.
            domains: Optional list of domain names to include (case-insensitive).
                When None, only ``blocksworld`` is included.

        Returns:
            GoalInferenceDataset with filtered instances.
        """
        local_path = hf_hub_download(
            repo_id="BatsResearch/planetarium",
            repo_type="dataset",
            filename=shard_filename,
        )

        allowed = {"blocksworld"} if domains is None else {d.lower() for d in domains}

        df = pl.read_parquet(local_path)
        df = (
            df.with_columns(
                # Normalise CRLF so the literal "(:goal (and" match below and
                # downstream string processing see consistent newlines.
                problem_pddl=pl.col("problem_pddl").str.replace(
                    "\r\n", "\n", literal=True
                ),
                natural_language=pl.col("natural_language").fill_null(""),
            )
            # Only conjunctive goals written exactly as "(:goal (and" are kept.
            .filter(pl.col("problem_pddl").str.contains("(:goal (and", literal=True))
            .with_columns(
                # Keep only the text after the last "Your goal" marker,
                # re-prefixed with "Your goal".
                goal_natural_language=pl.concat_str(
                    pl.lit("Your goal"),
                    pl.col("natural_language").str.split(by="Your goal").list.last(),
                ),
            )
            .filter(
                (pl.col("domain").str.to_lowercase().is_in(list(allowed)))
                & (pl.col("num_objects") <= max_objects)
                & (pl.col("init_is_abstract") == 0)
                & (pl.col("goal_is_abstract") == 0)
            )
            # keep="first" + maintain_order=True make deduplication
            # deterministic. Polars' defaults (keep="any", unordered output)
            # may keep different rows / a different order on each run, which
            # would change the seeded sample below.
            .unique(
                subset=["goal_natural_language"], keep="first", maintain_order=True
            )
            .sample(fraction=1, shuffle=True, seed=1234)
            .head(n_examples)
            .sort("id")
            .select(
                pl.col("id").alias("instance_id"),
                pl.col("goal_natural_language").alias("nl_goal"),
                pl.col("problem_pddl").alias("problem_text"),
                pl.col("domain").str.to_lowercase().alias("domain_name"),
            )
        )

        items = df.to_dicts()
        return cls(items)

    @property
    def schema(self):
        """The instance type yielded by this dataset."""
        return GoalInferenceInstance

__init__(dev_items)

Store preprocessed records.

Source code in genlm/eval/domains/goal_inference/goal_inference.py
def __init__(self, dev_items: List[dict]):
    """Keep a reference to the already-preprocessed record dictionaries.

    The list is stored as-is (not copied) on ``self.dev_items``.
    """
    records = dev_items
    self.dev_items = records

__iter__()

Yield GoalInferenceInstance objects built from stored records.

Source code in genlm/eval/domains/goal_inference/goal_inference.py
def __iter__(self):
    """Build and yield one GoalInferenceInstance per usable stored record.

    A record is skipped when either the prompt prefix or the masked
    reference cannot be derived from its problem text. When a record has no
    ``instance_id``, its index in ``dev_items`` is used instead.
    """
    for position, record in enumerate(self.dev_items):
        text = record["problem_text"]
        prefix = self._make_prefix_pddl(text)
        masked = self._mask_goal_for_reference(text)
        if prefix is not None and masked is not None:
            yield GoalInferenceInstance(
                nl_goal=record["nl_goal"],
                problem_text=text,
                masked_pddl=masked,
                prefix_pddl=prefix,
                instance_id=record.get("instance_id", position),
                domain_name=record["domain_name"],
            )

from_hf_planetarium(n_examples=100, max_objects=9, shard_filename='data/train-00000-of-00001.parquet', domains=None) classmethod

Load and filter Planetarium data via HuggingFace.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `n_examples` | `int` | Number of instances to evaluate. | `100` |
| `max_objects` | `int` | Keep problems with at most this many objects. | `9` |
| `shard_filename` | `str` | Specific shard file to download from Planetarium. | `'data/train-00000-of-00001.parquet'` |
| `domains` | `Optional[List[str]]` | Optional list of domain names to include (only `blocksworld` when `None`). | `None` |

Returns:

| Type | Description |
|------|-------------|
| `GoalInferenceDataset` | GoalInferenceDataset with filtered instances. |

Source code in genlm/eval/domains/goal_inference/goal_inference.py
@classmethod
def from_hf_planetarium(
    cls,
    n_examples: int = 100,
    max_objects: int = 9,
    shard_filename: str = "data/train-00000-of-00001.parquet",
    domains: Optional[List[str]] = None,
) -> "GoalInferenceDataset":
    """Load and filter Planetarium data via HuggingFace.

    Args:
        n_examples: Number of instances to evaluate.
        max_objects: Keep problems with at most this many objects.
        shard_filename: Specific shard file to download from Planetarium.
        domains: Optional list of domain names to include (case-insensitive).
            When None, only ``blocksworld`` is included.

    Returns:
        GoalInferenceDataset with filtered instances.
    """
    local_path = hf_hub_download(
        repo_id="BatsResearch/planetarium",
        repo_type="dataset",
        filename=shard_filename,
    )

    allowed = {"blocksworld"} if domains is None else {d.lower() for d in domains}

    df = pl.read_parquet(local_path)
    df = (
        df.with_columns(
            # Normalise CRLF so the literal "(:goal (and" match below and
            # downstream string processing see consistent newlines.
            problem_pddl=pl.col("problem_pddl").str.replace(
                "\r\n", "\n", literal=True
            ),
            natural_language=pl.col("natural_language").fill_null(""),
        )
        # Only conjunctive goals written exactly as "(:goal (and" are kept.
        .filter(pl.col("problem_pddl").str.contains("(:goal (and", literal=True))
        .with_columns(
            # Keep only the text after the last "Your goal" marker,
            # re-prefixed with "Your goal".
            goal_natural_language=pl.concat_str(
                pl.lit("Your goal"),
                pl.col("natural_language").str.split(by="Your goal").list.last(),
            ),
        )
        .filter(
            (pl.col("domain").str.to_lowercase().is_in(list(allowed)))
            & (pl.col("num_objects") <= max_objects)
            & (pl.col("init_is_abstract") == 0)
            & (pl.col("goal_is_abstract") == 0)
        )
        # keep="first" + maintain_order=True make deduplication
        # deterministic. Polars' defaults (keep="any", unordered output) may
        # keep different rows / a different order on each run, which would
        # change the seeded sample below.
        .unique(subset=["goal_natural_language"], keep="first", maintain_order=True)
        .sample(fraction=1, shuffle=True, seed=1234)
        .head(n_examples)
        .sort("id")
        .select(
            pl.col("id").alias("instance_id"),
            pl.col("goal_natural_language").alias("nl_goal"),
            pl.col("problem_pddl").alias("problem_text"),
            pl.col("domain").str.to_lowercase().alias("domain_name"),
        )
    )

    items = df.to_dicts()
    return cls(items)

GoalInferenceEvaluator

Bases: Evaluator[GoalInferenceInstance]

Evaluator using Planetarium equivalence on masked PDDL reconstruction.

Source code in genlm/eval/domains/goal_inference/goal_inference.py
class GoalInferenceEvaluator(Evaluator[GoalInferenceInstance]):
    """Score goal predictions by Planetarium equivalence with the reference.

    The model's prediction is spliced into the instance's masked PDDL. The
    resulting problem is then compared against the full reference problem.
    """

    def evaluate_sample(
        self, instance: GoalInferenceInstance, response: str
    ) -> EvaluationResult:
        """Splice ``response`` into the masked problem and test equivalence.

        Args:
            instance (GoalInferenceInstance): Item under evaluation; its
                ``masked_pddl`` must contain a ``[BLANK]`` marker.
            response (str): Model continuation after ``(:goal (and``, expected
                without the closing paren (one ``)`` is appended here).

        Returns:
            EvaluationResult: score 1.0 with desc "equiv" when equivalent;
            otherwise 0.0 with a desc naming the reason.
        """
        template = instance.masked_pddl
        reference = instance.problem_text
        if not (template and reference):
            return EvaluationResult(score=0.0, desc="missing_problem_or_masked")
        if "[BLANK]" not in template:
            return EvaluationResult(score=0.0, desc="no_blank_marker")

        # A None response is treated like an empty prediction.
        prediction = "" if response is None else response.strip()
        candidate = template.replace("[BLANK]", prediction + ")")

        try:
            # Element 2 of planetarium.evaluate's result is used as the
            # equivalence flag.
            is_equivalent = planetarium.evaluate(reference, candidate)[2]
        except (ValueError, AttributeError):
            return EvaluationResult(score=0.0, desc="planetarium_error")

        if is_equivalent:
            score, desc = 1.0, "equiv"
        else:
            score, desc = 0.0, "not_equiv"
        return EvaluationResult(
            score=score,
            desc=desc,
            metadata={"candidate": candidate},
        )

evaluate_sample(instance, response)

Inject prediction into masked PDDL and check equivalence.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `instance` | `GoalInferenceInstance` | The goal-inference item being evaluated. | *required* |
| `response` | `str` | Model output to splice into the goal (no closing paren). | *required* |

Returns:

| Type | Description |
|------|-------------|
| `EvaluationResult` | EvaluationResult with score 1.0 if equivalent, else 0.0. |

Source code in genlm/eval/domains/goal_inference/goal_inference.py
def evaluate_sample(
    self, instance: GoalInferenceInstance, response: str
) -> EvaluationResult:
    """Splice ``response`` into the masked problem and test equivalence.

    Args:
        instance (GoalInferenceInstance): Item under evaluation; its
            ``masked_pddl`` must contain a ``[BLANK]`` marker.
        response (str): Model continuation after ``(:goal (and``, expected
            without the closing paren (one ``)`` is appended here).

    Returns:
        EvaluationResult: score 1.0 with desc "equiv" when equivalent;
        otherwise 0.0 with a desc naming the reason.
    """
    template = instance.masked_pddl
    reference = instance.problem_text
    if not (template and reference):
        return EvaluationResult(score=0.0, desc="missing_problem_or_masked")
    if "[BLANK]" not in template:
        return EvaluationResult(score=0.0, desc="no_blank_marker")

    # A None response is treated like an empty prediction.
    prediction = "" if response is None else response.strip()
    candidate = template.replace("[BLANK]", prediction + ")")

    try:
        # Element 2 of planetarium.evaluate's result is used as the
        # equivalence flag.
        is_equivalent = planetarium.evaluate(reference, candidate)[2]
    except (ValueError, AttributeError):
        return EvaluationResult(score=0.0, desc="planetarium_error")

    if is_equivalent:
        score, desc = 1.0, "equiv"
    else:
        score, desc = 0.0, "not_equiv"
    return EvaluationResult(
        score=score,
        desc=desc,
        metadata={"candidate": candidate},
    )

goal_default_prompt_formatter(tokenizer, instance, use_chat_format=False, system_prompt=GOAL_SYSTEM_PROMPT)

Format the prompt to reproduce the reference assistant-prefix prompting.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tokenizer` | `Tokenizer` | The tokenizer to use. | *required* |
| `instance` | `GoalInferenceInstance` | The instance to format. | *required* |
| `use_chat_format` | `bool` | Whether to use chat format. | `False` |
| `system_prompt` | `str` | The system prompt to use. | `GOAL_SYSTEM_PROMPT` |

Returns:

| Type | Description |
|------|-------------|
| `list[int]` | The prompt ids. |

Source code in genlm/eval/domains/goal_inference/goal_inference.py
def goal_default_prompt_formatter(
    tokenizer,
    instance: GoalInferenceInstance,
    use_chat_format: bool = False,
    system_prompt: str = GOAL_SYSTEM_PROMPT,
):
    """Build prompt token ids ending with the instance's PDDL prefix.

    The prompt consists of the system prompt, the natural-language goal and
    ``instance.prefix_pddl`` (which ends with ``(:goal (and``). The model is
    meant to continue from that prefix.

    Args:
        tokenizer (Tokenizer): The tokenizer to use.
        instance (GoalInferenceInstance): The instance to format.
        use_chat_format (bool): If True, build a system/user/assistant message
            list and render it with ``tokenizer.apply_chat_template``;
            otherwise concatenate the parts as plain text and ``encode`` it.
        system_prompt (str): The system prompt to use.

    Returns:
        (list[int]): The prompt ids.
    """
    user_text = "Natural Language goal description: \n\n" + instance.nl_goal + "\n\n"

    if not use_chat_format:
        # Plain-text mode: parts are joined with no extra separator after
        # the system prompt.
        return tokenizer.encode(system_prompt + user_text + instance.prefix_pddl)

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": instance.prefix_pddl},
    ]
    # NOTE(review): with add_generation_prompt=True many chat templates close
    # the assistant message above and open a fresh assistant turn, so the
    # model may not continue prefix_pddl directly. Confirm this matches the
    # reference prompting (continue_final_message may be what is intended).
    return tokenizer.apply_chat_template(
        conversation=conversation, tokenize=True, add_generation_prompt=True
    )

GoalInferenceVALPotential

Bases: Potential

Expensive potential that validates partial goal strings with VAL.

It splices the current candidate goal into the problem, reuses a cached Fast-Downward plan for the (domain, problem) pair, and returns 0.0 if VAL validates that plan under the candidate goal (-inf otherwise). A prefix that does not yet contain a closing parenthesis scores 0.0 without calling VAL.

Source code in genlm/eval/domains/goal_inference/goal_potential.py
class GoalInferenceVALPotential(Potential):
    """Expensive potential that validates partial goal strings with VAL.

    It splices the current candidate goal into the problem, reuses a (cached)
    Fast-Downward plan for (domain, problem), and returns 0.0 if VAL
    validates the plan under the candidate goal (-inf otherwise). Candidates
    that do not yet contain a ``)`` score 0.0 without calling VAL.
    """

    def __init__(
        self,
        domain_pddl_text: str,
        problem_pddl_text: str,
        fast_downward_cmd: str = "./fast-downward.sif",
        val_cmd: str = "Validate",
        cache_root: Path | str = "benchmark/goal_inference/data",
        verbosity: int = 0,
    ):
        """Normalise inputs, write cache files and make sure a plan exists.

        Args:
            domain_pddl_text: Domain PDDL text.
            problem_pddl_text: Problem PDDL text whose goal is replaced by
                candidate goals during scoring.
            fast_downward_cmd: Command used to invoke Fast-Downward.
            val_cmd: Command used to invoke VAL's validator.
            cache_root: Directory under which task files and plans are cached.
            verbosity: Stored as ``self.verbosity`` (not read in this class).

        Raises:
            ValueError: If no cached plan exists and Fast-Downward fails.
        """
        # The vocabulary is the 256 byte values: this potential scores bytes.
        super().__init__(list(range(256)))
        self.domain = domain_pddl_text.replace("\r\n", "\n").replace("\r", "\n")
        self.problem = problem_pddl_text.replace("\r\n", "\n").replace("\r", "\n")
        self.fast_downward_cmd = fast_downward_cmd
        self.val_cmd = val_cmd
        self.verbosity = verbosity

        self.cache_root = Path(cache_root)
        self.tasks_dir = self.cache_root / "pddl_tasks"
        self.plans_dir = self.cache_root / "pddl_plans"
        self.tasks_dir.mkdir(parents=True, exist_ok=True)
        self.plans_dir.mkdir(parents=True, exist_ok=True)

        # Hash both domain + problem to key plan cache
        key = hashlib.sha256(
            (self.domain + "\n---\n" + self.problem).encode("utf-8")
        ).hexdigest()
        self.task_path = self.tasks_dir / f"{key}.pddl"
        self.plan_path = self.plans_dir / f"{key}.pddl"
        self.domain_path = self.tasks_dir / f"{key}.domain.pddl"

        if not self.task_path.exists():
            self.task_path.write_text(self.problem, encoding="utf-8")
        if not self.domain_path.exists():
            self.domain_path.write_text(self.domain, encoding="utf-8")
        if not self.plan_path.exists():
            self._ensure_plan()

    def _ensure_plan(self):
        """Generate and cache a plan with Fast-Downward (A* with iPDB heuristic).

        Raises:
            ValueError: If Fast-Downward exits with a non-zero status.
        """
        cmd = (
            f"{shlex.quote(self.fast_downward_cmd)} "
            f"--plan-file {shlex.quote(str(self.plan_path))} "
            f"{shlex.quote(str(self.domain_path))} {shlex.quote(str(self.task_path))} "
            f'--search "astar(ipdb())"'
        )
        proc = subprocess.run(
            cmd,
            shell=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            encoding="utf-8",
        )
        if proc.returncode != 0:
            raise ValueError(f"[Fast-Downward] nonzero exit:\n{proc.stderr}")

    @staticmethod
    def _fix_context(s: str) -> str | None:
        """Truncate ``s`` just after its last ``)``, dropping a partial atom.

        Args:
            s: Current decoded context string.

        Returns:
            ``s`` cut after its last ``)``, or None when nothing precedes the
            last ``)`` (including when ``s`` contains no ``)`` at all).
        """
        head, _, _ = s.rpartition(")")
        if not head:
            return None
        return head + ")"

    @staticmethod
    def _splice_goal(problem_pddl: str, ctx: str) -> str:
        """Replace the goal section of `problem_pddl` with `(:goal (and {ctx}))`.

        The goal section is the parenthesised expression starting at the first
        ``(:goal`` and ending at its matching ``)``. Text before and after it
        (including the closing paren of ``define``) is preserved. This works
        for goals spread over several lines. ``ctx`` is inserted verbatim, so
        backslashes in model output are not interpreted as escape sequences
        (as they would be in a ``re.sub`` replacement template).

        Args:
            problem_pddl (str): Full problem text.
            ctx (str): Goal inner content to inject after `(and `.

        Returns:
            str: Modified problem PDDL with the new goal, or `problem_pddl`
            unchanged if no goal section (or no matching ``)``) is found.
        """
        start = problem_pddl.find("(:goal")
        if start == -1:
            return problem_pddl
        depth = 0
        for pos in range(start, len(problem_pddl)):
            char = problem_pddl[pos]
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    end = pos + 1
                    break
        else:
            # Unbalanced goal section: leave the problem untouched.
            return problem_pddl
        return problem_pddl[:start] + f"(:goal (and {ctx}))" + problem_pddl[end:]

    def _energy(self, s: str) -> float:
        """Compute potential: 0.0 if VAL validates, else -inf.

        Args:
            s: Current (partial) goal text emitted by the model.

        Returns:
            0.0 if ``s`` has nothing to check yet (see ``_fix_context``) or if
            VAL accepts the cached plan under the spliced goal; -inf if VAL
            exits with a non-zero status.
        """
        parsable = self._fix_context(s)
        if parsable is None:
            # No complete atom yet: nothing to validate, so do not prune.
            return 0.0
        gen = self._splice_goal(self.problem, parsable)
        with tempfile.TemporaryDirectory(prefix="val_goal_") as tmp:
            tmp_problem = Path(tmp) / "task.pddl"
            tmp_problem.write_text(gen, encoding="utf-8")
            cmd = (
                f"{shlex.quote(self.val_cmd)} "
                f"{shlex.quote(str(self.domain_path))} "
                f"{shlex.quote(str(tmp_problem))} "
                f"{shlex.quote(str(self.plan_path))}"
            )
            proc = subprocess.run(
                cmd,
                shell=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                encoding="utf-8",
            )
            return 0.0 if proc.returncode == 0 else float("-inf")

    async def prefix(self, context: bytes) -> float:
        """Score a partial prefix during generation.

        Args:
            context (bytes): Byte prefix generated so far.

        Returns:
            float: Potential value for the prefix (0.0 / -inf). Any exception,
            including a UTF-8 decode error, yields -inf.
        """
        try:
            return self._energy(context.decode("utf-8"))
        except Exception:
            return float("-inf")

    async def complete(self, context: bytes) -> float:
        """Score a completed sequence at EOS.

        A ``)`` is appended before scoring, matching the evaluator, which
        appends one ``)`` to the model response.

        Args:
            context: Final byte sequence (without our trailing ')' heuristic).

        Returns:
            Potential value for the complete string (0.0 / -inf). Any
            exception yields -inf.
        """
        try:
            return self._energy(context.decode("utf-8") + ")")
        except Exception:
            return float("-inf")

prefix(context) async

Score a partial prefix during generation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `context` | `bytes` | Byte prefix generated so far. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `float` | Potential value for the prefix (0.0 / -inf). |

Source code in genlm/eval/domains/goal_inference/goal_potential.py
async def prefix(self, context: bytes) -> float:
    """Score the bytes generated so far.

    Args:
        context (bytes): Byte prefix generated so far; decoded as UTF-8.

    Returns:
        float: Result of ``self._energy`` on the decoded text (0.0 / -inf).
        -inf if decoding or scoring raises any exception.
    """
    try:
        text = context.decode("utf-8")
        score = self._energy(text)
    except Exception:
        score = float("-inf")
    return score

complete(context) async

Score a completed sequence at EOS.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `context` | `bytes` | Final byte sequence (without our trailing ')' heuristic). | *required* |

Returns:

| Type | Description |
|------|-------------|
| `float` | Potential value for the complete string (0.0 / -inf). |

Source code in genlm/eval/domains/goal_inference/goal_potential.py
async def complete(self, context: bytes) -> float:
    """Score the finished byte sequence at EOS.

    A closing ``)`` is appended to the decoded text before scoring.

    Args:
        context: Final byte sequence (without our trailing ')' heuristic).

    Returns:
        Result of ``self._energy`` on the closed text (0.0 / -inf), or -inf
        if decoding or scoring raises any exception.
    """
    try:
        closed = context.decode("utf-8") + ")"
        score = self._energy(closed)
    except Exception:
        score = float("-inf")
    return score