Skip to content

evaluation

genlm.eval.domains.spider.spider_eval.evaluation

eval_exec_match_worker(db, p_str, g_str, pred, gold, result_queue)

Worker function to execute SQL queries and compare results. This function is intended to run in a separate process.

Source code in genlm/eval/domains/spider/spider_eval/evaluation.py
def eval_exec_match_worker(db, p_str, g_str, pred, gold, result_queue):
    """
    Worker function to execute SQL queries and compare results.
    This function is intended to run in a separate process.

    Exactly one boolean is put on ``result_queue``:

    * ``True`` when the predicted and gold queries produce the same
      per-column results, and
    * ``False`` otherwise, including when either query or the comparison
      raises.

    The connection is switched to ``PRAGMA query_only`` before any query
    runs. A predicted statement such as ``DROP TABLE`` or ``CREATE TABLE``
    therefore fails and counts as a mismatch. Without this, it would
    permanently modify the database used by later evaluations.

    Args:
        db: Database location passed to ``sqlite3.connect``.
        p_str: Predicted SQL query string.
        g_str: Gold SQL query string.
        pred: Parsed predicted query; ``pred["select"][1]`` is read as a
            sequence of ``(agg_id, val_unit)`` pairs.
        gold: Parsed gold query, same structure as ``pred``.
        result_queue: Object with a ``put`` method (e.g. a
            ``multiprocessing.Queue``) that receives the boolean result.
    """
    conn = None
    cursor = None
    try:
        conn = sqlite3.connect(db)
        # Predicted SQL is untrusted model output. With the sqlite3 module's
        # default transaction handling, DDL (DROP/CREATE TABLE) is not wrapped
        # in an implicit transaction. Closing without commit would therefore
        # not undo it, so forbid all writes up front.
        conn.execute("PRAGMA query_only = ON")
        cursor = conn.cursor()

        # A predicted query that fails to execute is simply a mismatch.
        try:
            cursor.execute(p_str)
            p_res = cursor.fetchall()
        except Exception:
            result_queue.put(False)
            return

        cursor.execute(g_str)
        q_res = cursor.fetchall()

        def res_map(res, val_units):
            # Map each selected column's value unit to that column's values.
            # Comparing dicts makes the check insensitive to the order of
            # columns in the SELECT clause (row order still matters). Columns
            # with identical value units overwrite each other.
            rmap = {}
            for idx, val_unit in enumerate(val_units):
                key = (
                    tuple(val_unit[1])
                    if not val_unit[2]
                    else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
                )
                rmap[key] = [r[idx] for r in res]
            return rmap

        p_val_units = [unit[1] for unit in pred["select"][1]]
        q_val_units = [unit[1] for unit in gold["select"][1]]
        result_queue.put(res_map(p_res, p_val_units) == res_map(q_res, q_val_units))

    except Exception:
        # Gold-query failures and malformed parse structures also count as
        # a mismatch rather than crashing the worker process.
        result_queue.put(False)

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

eval_exec_match(db, p_str, g_str, pred, gold, timeout=None)

Run eval_exec_match_worker in a separate process using multiprocessing, with an optional timeout.

Source code in genlm/eval/domains/spider/spider_eval/evaluation.py
def eval_exec_match(db, p_str, g_str, pred, gold, timeout=None):
    """
    Run ``eval_exec_match_worker`` in a child process, with an optional timeout.

    Args:
        db: Forwarded unchanged to ``eval_exec_match_worker``.
        p_str: Forwarded unchanged to ``eval_exec_match_worker``.
        g_str: Forwarded unchanged to ``eval_exec_match_worker``.
        pred: Forwarded unchanged to ``eval_exec_match_worker``.
        gold: Forwarded unchanged to ``eval_exec_match_worker``.
        timeout: Maximum number of seconds to wait for the child process,
            or ``None`` to wait until it exits.

    Returns:
        The value the worker put on the result queue, or ``False`` in two
        cases: the child was still running after ``timeout`` seconds (it is
        then terminated), or the child exited without reporting a result.
    """
    import queue  # only needed for the ``queue.Empty`` exception below

    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=eval_exec_match_worker, args=(db, p_str, g_str, pred, gold, result_queue)
    )
    try:
        process.start()
        # join(None) waits indefinitely, so no separate no-timeout branch is needed.
        process.join(timeout)

        if process.is_alive():
            process.terminate()
            process.join()
            print(f"Query execution timed out after {timeout} seconds.")
            return False

        # The child has exited. A child that put a result waits for the queue's
        # feeder thread to flush it before terminating, so a short bounded wait
        # suffices. A child that died without reporting (e.g. killed by a
        # signal) leaves the queue empty, and an unbounded get() would hang
        # forever.
        try:
            return result_queue.get(timeout=1)
        except queue.Empty:
            return False
    finally:
        # Release the parent's handle on the queue's pipe.
        result_queue.close()