
json

is_utf8_start_byte(n)

Checks if this is a byte that can appear at the start of a UTF-8 character.

Source code in genlm/control/potential/built_in/json.py
def is_utf8_start_byte(n: int) -> bool:
    """Checks if this is a byte that can appear at the
    start of a UTF-8 character."""
    assert 0 <= n < 256
    for prefix, mask in UTF8_START_BYTE_MASKS:
        if n & mask == prefix:
            return True
    return False
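
For illustration, a sketch of what the (prefix, mask) table could look like and how the check behaves. The real UTF8_START_BYTE_MASKS constant is defined elsewhere in json.py and is not shown on this page.

# Sketch only: each entry is a (prefix, mask) pair, and a byte n is a
# start byte when n & mask == prefix. The real constant lives in json.py.
UTF8_START_BYTE_MASKS = [
    (0b00000000, 0b10000000),  # 0xxxxxxx - single-byte (ASCII) character
    (0b11000000, 0b11100000),  # 110xxxxx - leading byte of a 2-byte sequence
    (0b11100000, 0b11110000),  # 1110xxxx - leading byte of a 3-byte sequence
    (0b11110000, 0b11111000),  # 11110xxx - leading byte of a 4-byte sequence
]

assert is_utf8_start_byte(ord("a"))  # ASCII bytes always start a character
assert is_utf8_start_byte(0xC3)      # first byte of "é" encoded as UTF-8
assert not is_utf8_start_byte(0xA9)  # 10xxxxxx continuation byte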

JustOneBlockIterable

Provides a single value (intended to be bytes from a context) and then records whether the reader tried to read past it. This lets us distinguish incomplete JSON from invalid JSON: incomplete input makes the reader ask for more data than it was given, while invalid input makes it fail early.

Source code in genlm/control/potential/built_in/json.py
class JustOneBlockIterable:
    """Provides a single value (intended to be bytes from a context)
    and then signals if the reader tried to read past it. This allows
    us to distinguish invalid JSON from incomplete JSON by seeing if
    the reader tried to read more than it had or failed early."""

    def __init__(self, block):
        self.__block = block
        self.read_past_first_block = False

    def __iter__(self):
        yield self.__block
        self.read_past_first_block = True
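
A toy illustration of the signal the iterable provides; the toy_reader function below is hypothetical and stands in for the library's JSON reader.

# Sketch only: a toy reader that either consumes everything it is given
# (and then asks for another block) or stops on its own, as a real reader
# would after hitting a parse error.
def toy_reader(blocks, stop_after):
    seen = 0
    for block in blocks:
        for _ in block:
            seen += 1
            if seen >= stop_after:
                return  # gave up early: never asks for a second block

incomplete = JustOneBlockIterable(b'{"key": 1')
toy_reader(incomplete, stop_after=100)   # exhausts the block and wants more
assert incomplete.read_past_first_block

invalid = JustOneBlockIterable(b'{]')
toy_reader(invalid, stop_after=2)        # "]" is a parse error: stops early
assert not invalid.read_past_first_block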

prune_to_validatable_prefix(context)

We don't want to run the JSON validator on objects that are in the middle of generating a string or a float. We also don't want to run it immediately at the end of a string, or on whitespace-only changes. This finds a reasonable prefix that ends at a "logical unit", which makes it a good place to check. We can then cache checks based on the relevant prefix.

Source code in genlm/control/potential/built_in/json.py
def prune_to_validatable_prefix(context):
    """We don't want to run the JSON validator on objects that are in the
    middle of generating a string or a float. We also don't want to run it
    immediately at the end of a string, or on whitespace changes. This finds
    us a reasonable prefix that ends at a "logical unit" that makes it a good
    place to check. We can then cache checks based on the relevant prefix.
    """
    assert isinstance(context, bytes)
    try:
        context.decode("utf-8")
    except UnicodeDecodeError as e:
        if e.reason == "unexpected end of data":
            context = context[: e.start]
        else:
            raise

    for i in range(len(context) - 1, -1, -1):
        if context[i] in b"}],":
            return context[: i + 1]
    return b""
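
A few hypothetical inputs illustrating the behaviour: the context is cut back to the last closing brace, bracket, or comma, after dropping any trailing partial UTF-8 character.

# Mid-string (or mid-number) suffixes are trimmed back to the last
# "logical unit" boundary...
assert prune_to_validatable_prefix(b'{"a": [1, 2], "b": "par') == b'{"a": [1, 2],'
# ...a trailing partial UTF-8 character (here the first byte of an
# encoded "é") is dropped before pruning...
assert prune_to_validatable_prefix(b'[1, 2, "\xc3') == b'[1, 2,'
# ...and if no boundary has been produced yet, we get the empty prefix.
assert prune_to_validatable_prefix(b'{"unfinished') == b""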

Input

Convenience wrapper to provide a stateful stream-like interface that makes it easier to write parsers.

Source code in genlm/control/potential/built_in/json.py
class Input:
    """Convenience wrapper to provide a stateful stream-like interface
    that makes it easier to write parsers."""

    def __init__(self, incoming: AsyncIterator[str]):
        self.__incoming = incoming
        self.__finished = False
        # There's no textarray equivalent, so we store the growable
        # string as an array of integer codepoints.
        self.buffer = array("I")
        self.__index = 0
        self.__in_preserving_block = False

    @property
    def index(self):
        return self.__index

    @index.setter
    def index(self, value):
        assert value >= self.__index
        self.__index = value

    def __repr__(self):
        buffer = "".join(chr(i) for i in self.buffer)
        i = self.index
        return f"Input({repr(buffer[:i])}, ||, {repr(buffer[i:])})"

    async def advance_input(self):
        if self.__finished:
            return False
        try:
            next_block = await self.__incoming.more()
            self.buffer.extend([ord(c) for c in next_block])
            return True
        except StopAsyncIteration:
            self.__finished = True
            return False

    async def __read_until(self, condition):
        while True:
            if condition():
                break
            if not await self.advance_input():
                raise Incomplete()

    async def read_pattern(self, pattern, group=0):
        await self.__read_until(lambda: self.index < len(self.buffer))
        while True:
            # Having to convert the whole thing to a string here is really
            # annoying, but in practice the inefficiency is dwarfed by the LLM
            # so hopefully we don't have to worry about it.
            buffer = "".join(chr(i) for i in self.buffer[self.index :])
            match = pattern.match(buffer, pos=0, partial=True)
            if match is None or (result := match.group(group)) is None:
                raise ParseError()
            elif match.partial:
                if not await self.advance_input():
                    raise Incomplete()
            else:
                self.index += match.end()
                return result

    async def get_partial_pattern(self, pattern):
        """If the remainder of the buffer read so far could match a prefix
        of pattern, or start with a complete match for the pattern return it.

        Note: This is pure lookahead and does *not* advance the input."""

        await self.__read_until(lambda: self.index < len(self.buffer))
        buffer = "".join(chr(i) for i in self.buffer[self.index :])
        return pattern.match(buffer, pos=0, partial=True)

    async def current_char(self):
        await self.__read_until(lambda: self.index < len(self.buffer))
        return chr(self.buffer[self.index])

    async def read(self, n) -> str:
        await self.__read_until(lambda: self.index + n <= len(self.buffer))
        result = self.buffer[self.index : self.index + n]
        assert len(result) == n
        self.index += n
        return "".join(map(chr, result))

    async def expect(self, expected: str):
        for c in expected:
            actual = await self.read(1)
            if actual != c:
                raise ParseError(
                    f"Expected: {c} but got {actual} at index {self.index - 1}"
                )

    @contextmanager
    def preserving_index(self):
        """Only advance the index if the operation in the context block does
        not error."""
        start = self.index
        try:
            yield
        except Exception:
            self.__index = start
            raise

    @contextmanager
    def resetting_index(self):
        """Always reset the index to where it started at the end of this block."""
        start = self.index
        try:
            yield
        finally:
            self.__index = start

    async def parse(self, parser: "Parser[T]") -> T:
        with self.preserving_index():
            return await parser.parse(self)

    async def skip_whitespace(self):
        if self.index == len(self.buffer):
            if not await self.advance_input():
                return
        try:
            await self.parse(WHITESPACE_PARSER)
        except Incomplete:
            pass
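
A minimal sketch of driving Input by hand. OneShotSource below is a hypothetical stand-in for the source object whose more() method advance_input awaits; the module's own TrivialSource (used by Parser.parse_string further down, but not shown on this page) presumably plays that role in practice.

import asyncio

class OneShotSource:
    # Hypothetical stand-in for the source protocol Input expects: an object
    # with an async more() method that raises StopAsyncIteration once it has
    # nothing left to yield.
    def __init__(self, text):
        self.__text = text

    async def more(self):
        if self.__text is None:
            raise StopAsyncIteration()
        block, self.__text = self.__text, None
        return block

async def demo():
    inp = Input(OneShotSource("[true]"))
    await inp.expect("[")                   # consume the opening bracket
    assert await inp.current_char() == "t"  # peek without consuming
    assert await inp.read(4) == "true"
    print(inp)                              # Input('[true', ||, ']')

asyncio.run(demo())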

get_partial_pattern(pattern) async

If the remainder of the buffer read so far could match a prefix of the pattern, or starts with a complete match for the pattern, return the match.

Note: This is pure lookahead and does not advance the input.

Source code in genlm/control/potential/built_in/json.py
async def get_partial_pattern(self, pattern):
    """If the remainder of the buffer read so far could match a prefix
    of pattern, or start with a complete match for the pattern return it.

    Note: This is pure lookahead and does *not* advance the input."""

    await self.__read_until(lambda: self.index < len(self.buffer))
    buffer = "".join(chr(i) for i in self.buffer[self.index :])
    return pattern.match(buffer, pos=0, partial=True)
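
The partial=True keyword and the match.partial attribute come from the third-party regex package (the standard library's re module has no partial matching). A standalone sketch, with a hypothetical keyword pattern, of the two cases a caller has to distinguish:

import regex

KEYWORD = regex.compile(r"true|false|null")

# A complete match at the start of the lookahead window.
m = KEYWORD.match("true, 1]", pos=0, partial=True)
assert m is not None and not m.partial and m.group() == "true"

# The window ends mid-keyword: the match is reported as partial, telling
# the caller that more input could still turn this into a real match.
m = KEYWORD.match("fal", pos=0, partial=True)
assert m is not None and m.partial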

preserving_index()

Only advance the index if the operation in the context block does not error.

Source code in genlm/control/potential/built_in/json.py
@contextmanager
def preserving_index(self):
    """Only advance the index if the operation in the context block does
    not error."""
    start = self.index
    try:
        yield
    except Exception:
        self.__index = start
        raise
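
Continuing the Input sketch above (OneShotSource and asyncio as defined there): a parse attempt that fails inside the block leaves the index where it started.

async def demo():
    inp = Input(OneShotSource("nul"))
    try:
        with inp.preserving_index():
            await inp.expect("null")   # runs out of input after "nul"
    except Incomplete:
        pass
    assert inp.index == 0              # the failed attempt left no trace

asyncio.run(demo())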

resetting_index()

Always reset the index to where it started at the end of this block.

Source code in genlm/control/potential/built_in/json.py
@contextmanager
def resetting_index(self):
    """Always reset the index to where it started at the end of this block."""
    start = self.index
    try:
        yield
    finally:
        self.__index = start
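
With the same stand-in, resetting_index gives pure lookahead: reads inside the block never move the index for later callers.

async def demo():
    inp = Input(OneShotSource("[1, 2]"))
    with inp.resetting_index():
        assert await inp.read(3) == "[1,"   # look ahead three characters
    assert await inp.read(1) == "["         # index is back at the start

asyncio.run(demo())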

Parser

Bases: Generic[T]

Very basic parser combinators for mostly unambiguous grammars.

Source code in genlm/control/potential/built_in/json.py
class Parser(Generic[T]):
    """Very basic parser combinators for mostly unambiguous grammars."""

    async def parse(self, input: Input) -> T: ...

    async def parse_string(self, s: str) -> T:
        return await Input(TrivialSource(s)).parse(self)

    def __floordiv__(self, other: Generic[S]) -> "Parser[Union[T, S]]":
        return AltParser(self, other)

    def drop_result(self) -> "Parser[None]":
        return self.map(lambda x: None)

    def map(self, apply: Callable[[T], S]) -> "Parser[S]":
        return MapParser(self, apply)

    def filter(self, predicate: Callable[[T], bool]) -> "Parser[T]":
        return FilterParser(self, predicate)
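
A small sketch of composing these combinators with FixedSetParser (documented below). The value sets and lambdas are hypothetical, parse_string relies on the module's own TrivialSource internally, and the fallback behaviour assumes AltParser (not shown on this page) tries its alternatives in order.

import asyncio

# Hypothetical parsers built from FixedSetParser and the combinators above;
# the real module's JSON parsers are defined elsewhere in json.py.
BOOLEAN = FixedSetParser(["true", "false"]).map(lambda s: s == "true")
NULL = FixedSetParser(["null"]).drop_result()
LITERAL = BOOLEAN // NULL   # ordered choice via AltParser

assert asyncio.run(BOOLEAN.parse_string("false")) is False
# Assuming AltParser tries its alternatives in order, LITERAL accepts a
# boolean first and falls back to "null" otherwise.
assert asyncio.run(LITERAL.parse_string("true")) is True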

FixedSetParser

Bases: Parser[str]

Parser that matches a precise set of strings, some of which might be prefixes of each other, always returning the longest matching one.

Source code in genlm/control/potential/built_in/json.py
class FixedSetParser(Parser[str]):
    """Parser that matches a precise set of strings, some of which might
    be prefixes of each other, always returning the longest matching one."""

    def __init__(self, values):
        super().__init__()
        if not values:
            raise ValueError("values for FixedSetParser cannot be empty")
        self.trie = PatriciaTrie(values)
        self.values = values

    def __repr__(self):
        return f"FixedSetParser({self.values})"

    async def parse(self, input: Input) -> str:
        start = input.index
        match_length = -1
        node = self.trie.root

        with input.resetting_index():
            while True:
                assert node.accepting or node.children or node is self.trie.root

                try:
                    await input.expect(node.prefix)
                except (Incomplete, ParseError):
                    if match_length < 0:
                        raise
                    else:
                        break
                if node.accepting:
                    match_length = input.index - start
                    if not node.children:
                        break
                try:
                    c = await input.read(1)
                except Incomplete:
                    if match_length < 0:
                        raise
                    else:
                        break
                try:
                    node = node.children[c]
                except KeyError:
                    if match_length < 0:
                        raise ParseError(
                            f"Unexpected character {c}. Expected one of {repr(''.join(node.children))}"
                        )
                    break

        # Should have errored in the loop if not
        assert match_length >= 0
        result = await input.read(match_length)
        assert input.index == start + match_length
        return result
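
A quick check of the longest-match behaviour with a hypothetical value set (parse_string wraps the string in the module's TrivialSource):

import asyncio

# Hypothetical value set in which "no" is a proper prefix of "none".
parser = FixedSetParser(["no", "none", "null"])

# The longest value that prefixes the input wins...
assert asyncio.run(parser.parse_string("nonetheless")) == "none"
# ...and the shorter value is still returned when the longer one fails.
assert asyncio.run(parser.parse_string("not")) == "no"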