Text to SQL (Spider)¶
This example shows how to evaluate a genlm.control model on the Spider domain.
- Task: Generate SQL queries from a natural language question paired with its corresponding database schema.
- Data: Development split of the Spider dataset (Yu et al., 2018).
Setup¶
First, install the dependencies for this domain. In the root directory, run:
pip install -e .[spider]
Download the punkt_tab data for nltk:
python -m nltk.downloader punkt_tab
To run the full Spider evaluation, download the Spider dataset via:
gdown 'https://drive.google.com/u/0/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J&export=download'
unzip spider_data.zip
For this example, we'll use the assets/spider/spider_sample directory, which contains a small subset of the Spider dataset.
We'll also use the grammars provided in assets/spider/grammars.json, a JSON file that maps each database schema name to a Lark grammar.
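For illustration, here is a minimal sketch of the grammars.json layout. The toy grammar string below is a hypothetical stand-in, not the actual Spider grammar for this schema:

```python
import json

# Hypothetical illustration of the grammars.json layout: each key is a
# database schema name and each value is a Lark grammar string.
# The toy grammar below is a stand-in, not the real Spider grammar.
example_json = json.dumps(
    {"concert_singer": 'start: "SELECT" WORD+\n%import common.WORD\n%ignore " "'}
)
grammars = json.loads(example_json)
print(sorted(grammars))  # schema names covered by the file
```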
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid Hugging Face tokenizers warnings
Usage¶
Initialize the dataset and evaluator¶
from genlm.eval.domains.spider import SpiderDataset, SpiderEvaluator
spider_data_dir = "../../../assets/spider/spider_sample" # Replace with your path to the spider dataset
spider_grammars = "../../../assets/spider/grammars.json" # Replace with your path to the spider grammars
dataset = SpiderDataset.from_spider_dir(
spider_data_dir, grammar_json_path=spider_grammars, few_shot_example_ids=[0, 1]
)
evaluator = SpiderEvaluator(spider_data_dir)
Define a model adaptor¶
A model adaptor is an async callable that takes a dataset instance and returns a ModelOutput. For this example, we'll use a grammar-constrained genlm.control.PromptedLLM to generate responses.
from genlm.control import PromptedLLM, AWRS, BoolCFG
from genlm.eval import ModelOutput, ModelResponse
from genlm.eval.domains.spider import default_prompt_formatter
# Load an LLM
LLM = PromptedLLM.from_name("gpt2", eos_tokens=[b"\n", b"\n\n"])
async def model(instance, output_dir, replicate):
# Set the prompt for the LLM.
LLM.prompt_ids = default_prompt_formatter(
LLM.model.tokenizer, instance, use_chat_format=False
)
# Define a potential that constrains generations to the schema's SQL grammar
potential = BoolCFG.from_lark(instance.lark_grammar).coerce(LLM, f=b"".join)
# Define an adaptive weighted rejection sampler to sample tokens from the constrained model.
sampler = AWRS(LLM, potential)
# Run SMC to sample sequences from the constrained model.
sequences = await sampler.smc(
n_particles=2,
ess_threshold=0.5,
max_tokens=100,
)
# Return the sampled sequences and their probabilities as a ModelOutput.
return ModelOutput(
responses=[
ModelResponse(response=sequence, weight=prob)
for sequence, prob in sequences.decoded_posterior.items()
],
)
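The decoded_posterior used above maps each decoded sequence to a normalized probability. A toy sketch of how such a posterior can be formed from SMC particle weights (an assumption about the aggregation; genlm.control's exact internals may differ):

```python
import math

# Toy particles: (decoded sequence, log-weight). Duplicated sequences
# have their normalized weights summed into one posterior entry.
particles = [
    ("SELECT count(*) FROM singer", math.log(0.6)),
    ("SELECT count(*) FROM singer", math.log(0.3)),
    ("SELECT name FROM singer", math.log(0.1)),
]
total = sum(math.exp(lw) for _, lw in particles)
posterior = {}
for seq, lw in particles:
    posterior[seq] = posterior.get(seq, 0.0) + math.exp(lw) / total
print(posterior)  # probabilities sum to 1
```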
Run the evaluation¶
from genlm.eval import run_evaluation
results = await run_evaluation(
dataset=dataset,
model=model,
evaluator=evaluator,
max_instances=2,
n_replicates=1,
verbosity=1,
# output_dir="spider_results", optionally save the results to a directory
)
Instance utterance: How many singers do we have?, schema_name: concert_singer (id: 0)
Mean weighted accuracy (instance): 0.9534933025699144
Mean weighted accuracy (total): 0.9534933025699144
Instance utterance: What is the total number of singers?, schema_name: concert_singer (id: 1)
Mean weighted accuracy (instance): 0.9836474929075786
Mean weighted accuracy (total): 0.9685703977387465
results.keys()
dict_keys(['average_weighted_accuracy', 'n_instances', 'all_instance_results', 'all_instance_outputs'])
results["all_instance_outputs"]
[[ModelOutput(responses=[ModelResponse(response=' SELECT count(*) FROM singer', weight=0.9534933025699144, metadata=None), ModelResponse(response=' select Stadium_ID, Stadium_ID, Stadium_ID, Stadium_ID, stadium_ID, Stadium_ID, Stadium_ID + 1, Stadium_ID', weight=0.04650669743008568, metadata=None)], runtime_seconds=None, metadata=None)], [ModelOutput(responses=[ModelResponse(response=' SELECT count(*) FROM singer', weight=0.9836474929075786, metadata=None), ModelResponse(response=' select Stadium_ID, Stadium_ID, Stadium_ID, Stadium_ID, stadium_ID, Stadium_ID, Stadium_ID + 1, Stadium_ID', weight=0.016352507092421423, metadata=None)], runtime_seconds=None, metadata=None)]]
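The printed scores can be reproduced from these outputs. Assuming the evaluator scores each response by execution accuracy (1 if the query is correct, 0 otherwise) and averages with the posterior weights, the first instance's weighted accuracy is just the weight of its correct response:

```python
# Responses for instance 0, taken from the output above, with their
# posterior weights and (assumed) per-response execution accuracies.
responses = [
    (0.9534933025699144, 1.0),   # ' SELECT count(*) FROM singer' -- correct
    (0.04650669743008568, 0.0),  # degenerate Stadium_ID query -- incorrect
]
weighted_accuracy = sum(w * acc for w, acc in responses)
print(weighted_accuracy)  # matches the printed instance score
```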
References¶
Tao Yu, Rui Zhang, Kai Yang, Michihiro Yasunaga, Dongxu Wang, Zifan Li, James Ma, Irene Li, Qingning Yao, Shanelle Roman, Zilin Zhang, and Dragomir Radev. Spider: A large-scale human-labeled dataset for complex and cross-domain semantic parsing and text-to-SQL task. In Proceedings of the Conference on Empirical Methods in Natural Language Processing, 2018. URL https://aclanthology.org/D18-1425.