Skip to content

ds1000

genlm.eval.domains.ds1000

DS1000Dataset

Bases: Dataset[DS1000Instance]

Dataset for DS-1000 evaluation (Lai et al., 2023).

Source code in genlm/eval/domains/ds1000/ds1000.py
class DS1000Dataset(Dataset[DS1000Instance]):
    """Dataset for DS-1000 evaluation (Lai et al., 2023).

    Wraps a list of raw DS-1000 rows and yields one ``DS1000Instance`` per row.
    Rows are read through the keys ``prompt``, ``code_context``,
    ``reference_code`` and ``metadata``; missing keys fall back to empty values.
    """

    def __init__(self, rows: List[Mapping[str, Any]]):
        # Rows are kept as given; instances are built lazily in __iter__.
        self._rows = rows

    def __len__(self) -> int:
        return len(self._rows)

    def __iter__(self) -> Iterator[DS1000Instance]:
        """Yield one instance per row, using the row's position as ``instance_id``."""
        for i, row in enumerate(self._rows):
            yield DS1000Instance(
                # `or ""` also covers an explicit None value, which str() would
                # otherwise turn into the literal text "None".
                prompt=str(row.get("prompt") or "").strip(),
                code_context=str(row.get("code_context") or "").strip(),
                reference_code=row.get("reference_code"),
                metadata=(row.get("metadata") or {}),
                instance_id=i,
            )

    @property
    def schema(self) -> Type[DS1000Instance]:
        """The instance type produced by this dataset."""
        return DS1000Instance

    @classmethod
    def from_hf(
        cls,
        split: str = "test",
        libraries: Optional[Sequence[str]] = None,
        perturbation_types: Optional[Sequence[str]] = None,
        max_instances: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 1234,
        cache_dir: Optional[str] = None,
    ) -> "DS1000Dataset":
        """Load and (optionally) filter DS-1000 from Hugging Face.

        Args:
            split: Split name passed to ``load_dataset``.
            libraries: If given, keep only rows whose ``metadata["library"]``
                matches one of these names (case-insensitive). A single string
                is treated as one name, not as a sequence of characters.
            perturbation_types: Like ``libraries``, but matched against
                ``metadata["perturbation_type"]``.
            max_instances: If a non-negative int, keep at most this many rows
                (applied after filtering and shuffling). ``None`` or a negative
                value means no limit.
            shuffle: Shuffle the filtered rows before truncation.
            seed: Seed for the shuffle's private ``random.Random`` instance.
            cache_dir: Forwarded to ``load_dataset``.

        Returns:
            A ``DS1000Dataset`` over the selected rows.
        """
        ds = load_dataset("xlangai/DS-1000", split=split, cache_dir=cache_dir)
        rows: List[Mapping[str, Any]] = list(ds)

        def _lowered(values):
            # Normalize a filter argument to a lower-cased set; None means "no filter".
            if not values:
                return None
            if isinstance(values, str):
                values = [values]
            return {str(v).lower() for v in values}

        lib_set = _lowered(libraries)
        pt_set = _lowered(perturbation_types)

        def _meta_value(meta, key):
            # Lower-cased metadata field; a missing or None value becomes "".
            return str(meta.get(key) or "").lower()

        def _keep(r):
            m = r.get("metadata") or {}
            lib_ok = lib_set is None or _meta_value(m, "library") in lib_set
            pt_ok = pt_set is None or _meta_value(m, "perturbation_type") in pt_set
            return lib_ok and pt_ok

        rows = [r for r in rows if _keep(r)]

        if shuffle:
            # A private RNG leaves the global `random` state untouched.
            random.Random(seed).shuffle(rows)

        if isinstance(max_instances, int) and max_instances >= 0:
            rows = rows[:max_instances]

        log.info("Loaded DS-1000: %d instances (split=%s)", len(rows), split)
        return cls(rows)

from_hf(split='test', libraries=None, perturbation_types=None, max_instances=None, shuffle=False, seed=1234, cache_dir=None) classmethod

Load and (optionally) filter DS-1000 from Hugging Face.

Source code in genlm/eval/domains/ds1000/ds1000.py
@classmethod
def from_hf(
    cls,
    split: str = "test",
    libraries: Optional[Sequence[str]] = None,
    perturbation_types: Optional[Sequence[str]] = None,
    max_instances: Optional[int] = None,
    shuffle: bool = False,
    seed: int = 1234,
    cache_dir: Optional[str] = None,
) -> "DS1000Dataset":
    """Load and (optionally) filter DS-1000 from Hugging Face.

    Args:
        split: Split name passed to ``load_dataset``.
        libraries: If given, keep only rows whose ``metadata["library"]``
            matches one of these names (case-insensitive). A single string
            is treated as one name, not as a sequence of characters.
        perturbation_types: Like ``libraries``, but matched against
            ``metadata["perturbation_type"]``.
        max_instances: If a non-negative int, keep at most this many rows
            (applied after filtering and shuffling). ``None`` or a negative
            value means no limit.
        shuffle: Shuffle the filtered rows before truncation.
        seed: Seed for the shuffle's private ``random.Random`` instance.
        cache_dir: Forwarded to ``load_dataset``.

    Returns:
        A ``DS1000Dataset`` over the selected rows.
    """
    ds = load_dataset("xlangai/DS-1000", split=split, cache_dir=cache_dir)
    rows: List[Mapping[str, Any]] = list(ds)

    def _lowered(values):
        # Normalize a filter argument to a lower-cased set; None means "no filter".
        if not values:
            return None
        if isinstance(values, str):
            values = [values]
        return {str(v).lower() for v in values}

    lib_set = _lowered(libraries)
    pt_set = _lowered(perturbation_types)

    def _meta_value(meta, key):
        # Lower-cased metadata field; a missing or None value becomes "".
        return str(meta.get(key) or "").lower()

    def _keep(r):
        m = r.get("metadata") or {}
        lib_ok = lib_set is None or _meta_value(m, "library") in lib_set
        pt_ok = pt_set is None or _meta_value(m, "perturbation_type") in pt_set
        return lib_ok and pt_ok

    rows = [r for r in rows if _keep(r)]

    if shuffle:
        # A private RNG leaves the global `random` state untouched.
        random.Random(seed).shuffle(rows)

    if isinstance(max_instances, int) and max_instances >= 0:
        rows = rows[:max_instances]

    log.info("Loaded DS-1000: %d instances (split=%s)", len(rows), split)
    return cls(rows)

DS1000Evaluator

Bases: Evaluator[DS1000Instance]

Source code in genlm/eval/domains/ds1000/ds1000.py
class DS1000Evaluator(Evaluator[DS1000Instance]):
    """Score DS-1000 responses by running each instance's ``test_execution`` harness.

    Each response is post-processed into a solution and embedded in a generated
    harness script. The script runs in a separate Python subprocess. A sample
    scores 1.0 only when the subprocess exits with code 0, prints the pass
    marker, and prints no fail marker.
    """

    def __init__(
            self,
            python_executable: Optional[str] = None,
            timeout_seconds: float = 15.0,
            extra_env: Optional[Dict[str, str]] = None,
            max_log_chars: int = 4000
        ) -> None:
        """
        Args:
            python_executable: Interpreter used to run the harness; defaults
                to ``sys.executable``.
            timeout_seconds: Wall-clock limit for one harness run.
            extra_env: Extra environment variables for the subprocess; they
                override the default ``MPLBACKEND=Agg``.
            max_log_chars: Maximum length of ``EvaluationResult.desc`` before
                it is truncated.
        """
        self.python_executable = python_executable or sys.executable
        self.timeout_seconds = float(timeout_seconds)
        self.extra_env = dict(extra_env or {})
        self.max_log_chars = int(max_log_chars)
        # Markers for detecting PASS/FAIL in output
        self.marker_pass = "<<<DS1000_PASS>>>"
        self.marker_fail = "<<<DS1000_FAIL>>>"

    def assigns_result(self, code: str) -> bool:
        """Return True if *code* binds the name ``result`` anywhere.

        Recognized forms:
        - plain and chained (``result = ...``, ``a = result = ...``)
        - unpacking (``result, x = ...``, ``*result, = ...``)
        - annotated (``result: int = ...``)
        - augmented (``result += ...``)
        - walrus (``(result := ...)``)

        Code that fails to parse is reported as not assigning ``result``.
        """
        try:
            tree = ast.parse(code)
        except Exception:
            # Deliberate best-effort: any parse failure (SyntaxError, ValueError
            # on NUL bytes, RecursionError on deep nesting, ...) means "no".
            return False
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                targets = node.targets
            elif isinstance(node, (ast.AnnAssign, ast.AugAssign, ast.NamedExpr)):
                targets = [node.target]
            else:
                continue
            if any(self._binds_result(t) for t in targets):
                return True
        return False

    @staticmethod
    def _binds_result(target) -> bool:
        """Return True if an assignment target (possibly nested) binds the name ``result``."""
        if isinstance(target, ast.Name):
            return target.id == "result"
        if isinstance(target, ast.Starred):
            return DS1000Evaluator._binds_result(target.value)
        if isinstance(target, (ast.Tuple, ast.List)):
            return any(DS1000Evaluator._binds_result(e) for e in target.elts)
        # Attribute/subscript targets (obj.result, d["result"]) do not bind the name.
        return False

    def evaluate_sample(self, instance: DS1000Instance, response: str) -> EvaluationResult:
        """Score one response: 1.0 if the harness passes, else 0.0.

        ``desc`` is the post-processed solution, truncated to ``max_log_chars``.
        ``metadata`` is the instance's metadata dict.
        """
        solution = _postprocess_code(response)
        if not solution:
            return EvaluationResult(score=0.0, desc="empty solution", metadata=instance.metadata)

        script = self._build_harness_script(instance.code_context, solution)
        ok, _, _, _ = self._run_in_subprocess(script)

        # Only the solution is recorded; truncate it so stored descriptions stay bounded.
        if len(solution) <= self.max_log_chars:
            desc = solution
        else:
            desc = solution[:self.max_log_chars] + "\n...[truncated]"
        return EvaluationResult(score=1.0 if ok else 0.0, desc=desc, metadata=instance.metadata)

    def _build_harness_script(self, code_context: str, solution: str) -> str:
        """Load test_execution() and run it with solution.

        Returns the source of a standalone script with the following behavior:
        - it exec's ``code_context`` and calls its ``test_execution(solution)``;
        - it prints the pass marker and exits 0 only if that call neither
          raises nor returns ``False``;
        - every failure prints the fail marker and exits non-zero.
        """
        # SystemExit is caught separately: otherwise a sys.exit()/exit() in the
        # solution would bypass the Exception handler. A solution could then
        # print the pass marker and exit 0 to be scored as passing.
        # NOTE(review): os._exit() in the solution can still end the process
        # early; fully hardening this would need an unforgeable verdict channel.
        return textwrap.dedent(f"""
        # -*- coding: utf-8 -*-
        import sys, traceback

        code_context = {code_context!r}
        solution = {solution!r}
        g = {{"__name__": "__main__"}}
        try:
            exec(code_context, g, g)
        except BaseException as e:
            print("{self.marker_fail} HARNESS_EXEC_ERROR:", repr(e), flush=True)
            traceback.print_exc()
            sys.exit(1)

        test_execution = g.get("test_execution")
        if not callable(test_execution):
            print("{self.marker_fail} MISSING_test_execution", flush=True)
            sys.exit(1)

        try:
            _ret = test_execution(solution)
        except SystemExit as e:
            print("{self.marker_fail} SYSTEM_EXIT:", repr(e.code), flush=True)
            sys.exit(4)
        except Exception as e:
            print("{self.marker_fail}", repr(e), flush=True)
            traceback.print_exc()
            sys.exit(2)

        if _ret is False:
            print("{self.marker_fail} TEST_RETURNED_FALSE", flush=True)
            sys.exit(3)

        print("{self.marker_pass}", flush=True)
        sys.exit(0)
        """).strip()

    def _run_in_subprocess(self, script: str) -> Tuple[bool, int, str, str]:
        """Run *script* in a fresh temp directory.

        Returns:
            ``(ok, returncode, stdout, stderr)``, with stdout and stderr stripped.
            On timeout this is ``(False, 124, "", "timeout after ...")``.
        """
        with tempfile.TemporaryDirectory(prefix="ds1000_") as td:
            path = os.path.join(td, "harness.py")
            with open(path, "w", encoding="utf-8") as f:
                f.write(script + "\n")

            # Build a sandboxed env
            env = _sandbox_env(td, extra_env={**{"MPLBACKEND": "Agg"}, **self.extra_env})
            cmd = [self.python_executable, "-B", path]

            try:
                proc = subprocess.run(
                    cmd,
                    check=False,
                    capture_output=True,
                    text=True,
                    # Undecodable output bytes must not crash scoring; markers are ASCII.
                    errors="replace",
                    timeout=self.timeout_seconds,
                    env=env,
                    cwd=td,
                )
            except subprocess.TimeoutExpired:
                return (False, 124, "", f"timeout after {self.timeout_seconds:.1f}s")

            out = proc.stdout or ""
            err = proc.stderr or ""
            rc = int(proc.returncode)
            pass_line = self.marker_pass in out
            fail_line = (self.marker_fail in out) or (self.marker_fail in err)
            ok = (rc == 0) and pass_line and (not fail_line)
            return (ok, rc, out.strip(), err.strip())

DS1000Instance

Bases: Instance

Schema for a DS-1000 instance.

Source code in genlm/eval/domains/ds1000/ds1000.py
class DS1000Instance(Instance):
    """Schema for a DS-1000 instance.

    ``instance_id`` is passed by DS1000Dataset but not declared here; it is
    presumably declared on the ``Instance`` base class — TODO confirm.
    """

    # Problem statement shown to the model (whitespace-stripped by DS1000Dataset).
    prompt: str 
    # Source that, when exec'd, defines ``test_execution(solution)``; the
    # evaluator and the runtime potential both call it to check a solution.
    code_context: str 
    # Raw per-row metadata; DS1000Dataset.from_hf filters on its "library"
    # and "perturbation_type" keys.
    metadata: Dict[str, Any]
    # Reference solution from the dataset row, if present.
    reference_code: Optional[str] = None

DS1000RuntimeNoErrorPotential

Bases: Potential

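DS-1000 expensive potential: executes the DS-1000 test harness on a complete code prefix. Returns 0.0 if execution raises no runtime error (a failing assertion counts as no error), and -inf otherwise, including on syntax errors and timeouts.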

Source code in genlm/eval/domains/ds1000/runtime_no_error_potential.py
class DS1000RuntimeNoErrorPotential(Potential):
    """
    DS-1000 expensive potential: execute the harness on a complete prefix.
    Return 0.0 if no error, -inf otherwise.

    The decoded code is passed to the instance's ``test_execution`` harness,
    which runs in a separate Python subprocess. An AssertionError (a failed
    correctness check) still counts as "no error". Syntax errors, other
    exceptions, a missing ``test_execution`` and timeouts all score -inf.
    """

    def __init__(
        self,
        vocabulary=None,
        code_context: str = "",
        timeout_seconds: float = 30.0,
        python_executable: Optional[str] = None,
        extra_env: Optional[Dict[str, str]] = None,
        f: Optional[Callable[[List[bytes]], List[bytes]]] = None,
    ):
        # Default vocabulary: every single byte value.
        vocabulary = vocabulary or [bytes([i]) for i in range(256)]
        super().__init__(vocabulary=vocabulary)
        self.timeout_seconds = float(timeout_seconds)
        self.code_context = code_context
        self.python_executable = python_executable or sys.executable
        self.extra_env = dict(extra_env or {})
        # Whether the most recent harness run reported a syntax error.
        self.last_was_syntax_error = False
        # Optional transform applied to the token context before decoding.
        self.f = f

    def coerce(
        self,
        other,
        f: Optional[Callable[[List[bytes]], List[bytes]]] = None,
        prune: bool = True,
    ):
        """Return a copy of this potential over ``other``'s vocabulary.

        ``f`` is applied to each token context before it is decoded and scored.
        ``prune`` is accepted for interface compatibility but not used here.
        """
        return DS1000RuntimeNoErrorPotential(
            vocabulary=list(other.vocab),
            code_context=self.code_context,
            timeout_seconds=self.timeout_seconds,
            python_executable=self.python_executable,
            extra_env=self.extra_env,
            f=f,
        )

    def _bytes_to_str(self, toks):
        """Concatenate byte tokens and decode as UTF-8, dropping undecodable bytes."""
        if not toks:
            return ""
        # errors="ignore" never raises, so no fallback decoding is needed.
        return b"".join(toks).decode("utf-8", errors="ignore")

    async def prefix(self, context: List[bytes]) -> float:
        """Score a partial program.

        Only contexts ending at a line boundary are executed; any other prefix
        scores 0.0 without running the harness.
        """
        if self.f is not None:
            context = self.f(context)
        code = self._bytes_to_str(context)
        # Newline guardrail when using the default sampler.
        if not code.endswith("\n"):
            return 0.0
        code = _postprocess_code(code)
        return await self._score_no_error(code)

    async def complete(self, context: List[bytes]):
        """Score a finished program by running the harness on it."""
        # Apply transformation before processing
        if self.f is not None:
            context = self.f(context)
        code = self._bytes_to_str(context)
        code = _postprocess_code(code)
        return await self._score_no_error(code)

    async def _score_no_error(self, complete_code: str) -> float:
        """
        Run the harness script in a subprocess and check for runtime errors.
        Returns 0.0 if no error (incl. AssertionError), -inf otherwise.

        The blocking subprocess call runs in the event loop's default executor,
        so concurrent scoring calls do not stall the loop. Sets
        ``self.last_was_syntax_error`` as a side effect.

        complete_code: str - the complete code to run
        Returns:
            float - 0.0 if no error, -inf otherwise.
        """
        import asyncio

        OK, BAD, SYNTAX = "<<<OK>>>", "<<<BAD>>>", "<<<SYNTAX>>>"

        script = textwrap.dedent(
            f"""
            import sys, warnings, os, ast
            warnings.filterwarnings("ignore")
            os.environ.setdefault("PYTHONWARNINGS", "ignore")
            os.environ.setdefault("MPLBACKEND", "Agg")

            OK, BAD, SYNTAX = "<<<OK>>>", "<<<BAD>>>", "<<<SYNTAX>>>"
            code_context = {self.code_context!r}
            answer = {complete_code!r}
            try:
                g = {{}}
                exec(code_context, g, g)
                te = g.get("test_execution")
                if not callable(te):
                    print(BAD); raise SystemExit(0)
                try: # ast.parse to check if the answer is valid code
                    ast.parse(answer, filename="<answer>", mode="exec")
                except SyntaxError:
                    print(SYNTAX); raise SystemExit(0)
                try:
                    te(answer)
                    # If we get here with no exception, it ran without runtime error.
                    print(OK)
                except AssertionError:
                    print(OK)
                except SyntaxError:
                    print(SYNTAX)
                except Exception:
                    print(BAD)
            except AssertionError: # Safety check for AssertionError
                print(OK)
            except SyntaxError:
                print(SYNTAX)
            except Exception:
                print(BAD)
            """
        ).strip()

        loop = asyncio.get_running_loop()
        out = await loop.run_in_executor(None, self._run_script_blocking, script)
        if out is None:
            # Timed out: reset the flag so callers don't see a stale value.
            self.last_was_syntax_error = False
            return float("-inf")

        lines = [line.strip() for line in out.splitlines()]
        ok = OK in lines
        bad = BAD in lines
        syntax = SYNTAX in lines

        self.last_was_syntax_error = syntax
        return 0.0 if ok and not (bad or syntax) else float("-inf")

    def _run_script_blocking(self, script: str) -> Optional[str]:
        """Run *script* in a fresh temp dir; return stdout+stderr, or None on timeout."""
        try:
            with tempfile.TemporaryDirectory(prefix="ds1000_rt_") as td:
                path = os.path.join(td, "rt_harness.py")
                with open(path, "w", encoding="utf-8") as fh:
                    fh.write(script + "\n")

                env = _sandbox_env(
                    td,
                    extra_env={
                        **{"MPLBACKEND": "Agg", "PYTHONWARNINGS": "ignore"},
                        **self.extra_env,
                    },
                )

                proc = subprocess.run(
                    [self.python_executable, "-B", path],
                    check=False,
                    capture_output=True,
                    text=True,
                    # Undecodable output bytes must not crash scoring; markers are ASCII.
                    errors="replace",
                    timeout=self.timeout_seconds,
                    env=env,
                    cwd=td,
                )
        except subprocess.TimeoutExpired:
            return None
        return (proc.stdout or "") + (proc.stderr or "")