trie

TokenCharacterTrie

A trie data structure for efficient token-to-character mapping.
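
As a quick orientation, here is a minimal usage sketch. It assumes the genlm-backend package is installed and that the class is importable from `genlm.backend.trie` (a path inferred from the source location below); the toy vocabulary is hypothetical.

from genlm.backend.trie import TokenCharacterTrie

# A toy vocabulary of byte-string tokens. Each element must be iterable;
# iterating bytes yields individual byte values, which become trie edges.
vocab = [b"a", b"ab", b"b", b"ba"]
trie = TokenCharacterTrie(decode=vocab)

print(len(trie.children))     # total number of trie nodes
print(trie.word2leaf[b"ab"])  # leaf node id assigned to the token b"ab"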

Source code in genlm/backend/trie/base.py
class TokenCharacterTrie:
    """A trie data structure for efficient token-to-character mapping."""

    def __init__(self, decode):
        """Initialize a `TokenCharacterTrie`.

        Args:
            decode (list): List representing the token vocabulary.
                Each element of the list must be iterable.
        """
        self.decode = decode
        self.word2leaf = {}
        self.children = [{}]  # First node is root
        self.root = 0
        self.token_id_to_leaf = []

        for token_id, word in enumerate(self.decode):
            curr = self.root
            for letter in word:
                if letter not in self.children[curr]:
                    self.children[curr][letter] = len(self.children)
                    self.children.append({})
                curr = self.children[curr][letter]

            self.children[curr][None] = last = len(self.children)
            self.children.append({})
            assert word not in self.word2leaf, (
                "Can't have duplicate words in vocabulary"
            )
            self.word2leaf[word] = last

            self.token_id_to_leaf.append((token_id, last))

        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
        )
        self.ordering = np.array(list(self._order(self.root)), np.int32)

        # Renumber the states of the trie so that they are named by a
        # contiguous range of integers that respects the topological
        # ordering of the trie. This makes updating the trie more
        # efficient by improving memory locality.
        ordering = {}
        for i, x in enumerate(self._order_full(self.root)):
            ordering[x] = i
        self._rename(f=lambda x: ordering[x])

        node2prefix = {self.root: []}
        for x in reversed(range(len(self.children))):
            for letter, y in self.children[x].items():
                if letter is None:
                    node2prefix[y] = node2prefix[x]
                else:
                    node2prefix[y] = node2prefix[x] + [letter]
        self.node2prefix = node2prefix

    def _rename(self, f):
        """Rename all node indices in the trie using the provided mapping function.

        Args:
            f (callable): Function that maps old node indices to new node indices
        """
        N = len(self.children)

        new_children = [{} for _ in range(N)]
        nodes = range(N)

        for x in nodes:
            for letter, y in self.children[x].items():
                new_children[f(x)][letter] = f(y)

        self.root = f(self.root)
        self.children = new_children
        self.word2leaf = {w: f(x) for w, x in self.word2leaf.items()}
        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))

        self.token_id_to_leaf = np.array(
            [(i, f(x)) for i, x in self.token_id_to_leaf], dtype=np.int32
        )

        self.ordering = np.array([f(x) for x in self.ordering])
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in new_children]
        )

    def _alloc_weights(self):
        """Allocate an array to store weight values for all nodes.

        Returns:
            np.ndarray: Zero-initialized array for storing weight values
        """
        return np.zeros(len(self.children), dtype=np.float64)

    def _preprocess_ws(self, ws):
        """Preprocess the weight vector to ensure it is a numpy array and on the correct device.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Weight vector
        """
        if isinstance(ws, torch.Tensor):
            if ws.device.type != "cpu":
                ws = ws.cpu()
            ws = ws.numpy()
        return ws

    def weight_sum(self, ws):
        """Compute weight sum for each node in the trie.

        For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
        that are descendants of that node.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Summed weights for each node in the trie.
        """
        ws = self._preprocess_ws(ws)
        node_ws = self._alloc_weights()
        _update_trie_numba_sum(
            node_ws=node_ws,
            ws=ws,
            token_id_to_leaf=self.token_id_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return node_ws

    def weight_max(self, ws):
        """Compute weight max for each node in the trie.

        For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
        that are descendants of that node.

        Args:
            ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Weight max values for each node in the trie.
        """
        ws = self._preprocess_ws(ws)
        node_ws = self._alloc_weights()
        _update_trie_numba_max(
            node_ws=node_ws,
            ws=ws,
            token_id_to_leaf=self.token_id_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return node_ws

    def batch_weight_sum(self, ws):
        """Batched equivalent of `weight_sum`.

        Args:
            ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Array of shape `(len(ws), num_nodes)` with the summed weights for each node in the trie
        """
        return np.array([self.weight_sum(w) for w in ws])

    def batch_weight_max(self, ws):
        """Batched equivalent of `weight_max`.

        Args:
            ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

        Returns:
            (np.ndarray): Array of shape `(len(ws), num_nodes)` with the maximum weights for each node in the trie
        """
        return np.array([self.weight_max(w) for w in ws])

    def _order(self, node):
        """Generate a topological ordering of nodes beneath the given node.

        Args:
            node (int): Starting node index

        Yields:
            int: Node indices in topological order
        """
        for a in self.children[node]:
            if a is None:
                pass
            else:
                yield from self._order(self.children[node][a])
        yield node

    def _order_full(self, node):
        """Generate a complete topological ordering including all child nodes.

        Args:
            node (int): Starting node index

        Yields:
            (int): Node indices in complete topological order
        """
        for a in self.children[node]:
            yield from self._order_full(self.children[node][a])
        yield node

    def visualize(self, ws=None):
        """Visualize the trie structure using Graphviz.

        Args:
            ws (np.ndarray|None): Optional weight vector to display at each node.
                                Should be of length `len(self.children)`.

        Returns:
            (graphviz.Digraph): The generated graph object
        """
        try:
            import graphviz
        except ImportError:  # pragma: no cover
            raise ImportError(
                "Please install graphviz: pip install graphviz"
            )  # pragma: no cover

        if ws is not None and len(ws) != len(self.children):
            raise ValueError(
                f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
            )

        dot = graphviz.Digraph(comment="Token Character Trie")
        dot.attr(rankdir="LR")

        # Create a subgraph for the legend
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", fontsize="10")
            legend.attr("node", fontsize="7", width="0.1", height="0.1")

            # Example internal node
            legend.node(
                "legend_internal",
                "Internal Node ID\n'Prefix'\nWeight (if provided)",
                shape="circle",
            )

            # Example leaf node
            legend.node("legend_leaf", "Complete Token", shape="doublecircle")

            legend.edge(
                "legend_internal",
                "legend_leaf",
                label="Token item",
                fontsize="10",
            )

            # Lay out the legend top-to-bottom
            legend.attr(rankdir="TB")
            legend.attr(rank="same")

        # Add the main trie nodes and edges
        for node_id in range(len(self.children)):
            prefix = self.node2prefix[node_id]

            if ws is not None:
                label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
            else:
                label = f"{node_id}\n'{prefix}'"

            # Color nodes based on mass if provided
            if ws is not None:
                max_ws = ws.max()
                if max_ws > 0:
                    intensity = int(255 * (1 - ws[node_id] / max_ws))
                    color = f"#{intensity:02x}{255:02x}{intensity:02x}"
                else:
                    color = "#ffffff"  # white for zero mass
            else:
                color = "#ffffff"  # default white

            if node_id in self.leaf2word:
                dot.node(
                    str(node_id),
                    label,
                    shape="doublecircle",
                    style="filled",
                    fillcolor=color,
                )
            else:
                dot.node(
                    str(node_id), label, shape="circle", style="filled", fillcolor=color
                )

        for node_id, children in enumerate(self.children):
            for char, child_id in children.items():
                if char is not None:
                    edge_label = str(char)
                else:
                    edge_label = "End-of-Token"

                dot.edge(str(node_id), str(child_id), label=edge_label)

        return dot

__init__(decode)

Initialize a TokenCharacterTrie.

Parameters:

    decode (list, required): List representing the token vocabulary. Each element of the list must be iterable.
Source code in genlm/backend/trie/base.py
def __init__(self, decode):
    """Initialize a `TokenCharacterTrie`.

    Args:
        decode (list): List representing the token vocabulary.
            Each element of the list must be iterable.
    """
    self.decode = decode
    self.word2leaf = {}
    self.children = [{}]  # First node is root
    self.root = 0
    self.token_id_to_leaf = []

    for token_id, word in enumerate(self.decode):
        curr = self.root
        for letter in word:
            if letter not in self.children[curr]:
                self.children[curr][letter] = len(self.children)
                self.children.append({})
            curr = self.children[curr][letter]

        self.children[curr][None] = last = len(self.children)
        self.children.append({})
        assert word not in self.word2leaf, (
            "Can't have duplicate words in vocabulary"
        )
        self.word2leaf[word] = last

        self.token_id_to_leaf.append((token_id, last))

    self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
    self.jump = List(
        [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
    )
    self.ordering = np.array(list(self._order(self.root)), np.int32)

    # Renumber the states of the trie so that they are named by a
    # contiguous range of integers that respects the topological
    # ordering of the trie. This makes updating the trie more
    # efficient by improving memory locality.
    ordering = {}
    for i, x in enumerate(self._order_full(self.root)):
        ordering[x] = i
    self._rename(f=lambda x: ordering[x])

    node2prefix = {self.root: []}
    for x in reversed(range(len(self.children))):
        for letter, y in self.children[x].items():
            if letter is None:
                node2prefix[y] = node2prefix[x]
            else:
                node2prefix[y] = node2prefix[x] + [letter]
    self.node2prefix = node2prefix
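
To see the prefix bookkeeping at the end of `__init__` in action, the sketch below (hypothetical toy vocabulary) prints the character prefix recorded for each node; `visualize` uses these prefixes as node labels.

from genlm.backend.trie import TokenCharacterTrie

trie = TokenCharacterTrie(decode=[b"a", b"ab", b"b"])

# node2prefix maps every node id to the list of byte values on the path
# from the root; a leaf node shares the prefix of its parent.
for node_id in range(len(trie.children)):
    print(node_id, trie.node2prefix[node_id])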

weight_sum(ws)

Compute weight sum for each node in the trie.

For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node.

Parameters:

    ws (torch.Tensor | np.ndarray, required): Token weights over the vocabulary, of shape (len(self.decode),).

Returns:

    np.ndarray: Summed weights for each node in the trie.

Source code in genlm/backend/trie/base.py
def weight_sum(self, ws):
    """Compute weight sum for each node in the trie.

    For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
    that are descendants of that node.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Summed weights for each node in the trie.
    """
    ws = self._preprocess_ws(ws)
    node_ws = self._alloc_weights()
    _update_trie_numba_sum(
        node_ws=node_ws,
        ws=ws,
        token_id_to_leaf=self.token_id_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return node_ws
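
A small end-to-end sketch (hypothetical vocabulary and weights): since every node sums the weights of its descendant tokens, the root ends up with the total weight.

import numpy as np
from genlm.backend.trie import TokenCharacterTrie

trie = TokenCharacterTrie(decode=[b"a", b"ab", b"b", b"ba"])

ws = np.array([0.1, 0.2, 0.3, 0.4])  # one weight per token
node_ws = trie.weight_sum(ws)

# The root is an ancestor of every leaf, so it accumulates the total.
assert np.isclose(node_ws[trie.root], ws.sum())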

weight_max(ws)

Compute weight max for each node in the trie.

For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node.

Parameters:

    ws (torch.Tensor | np.ndarray, required): Token weights over the vocabulary, of shape (len(self.decode),).

Returns:

    np.ndarray: Weight max values for each node in the trie.

Source code in genlm/backend/trie/base.py
def weight_max(self, ws):
    """Compute weight max for each node in the trie.

    For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
    that are descendants of that node.

    Args:
        ws (torch.Tensor|np.ndarray): Token weights over the vocabulary of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Weight max values for each node in the trie.
    """
    ws = self._preprocess_ws(ws)
    node_ws = self._alloc_weights()
    _update_trie_numba_max(
        node_ws=node_ws,
        ws=ws,
        token_id_to_leaf=self.token_id_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return node_ws
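
The analogous sketch for `weight_max` (same hypothetical setup): the root now carries the single largest token weight.

import numpy as np
from genlm.backend.trie import TokenCharacterTrie

trie = TokenCharacterTrie(decode=[b"a", b"ab", b"b", b"ba"])
ws = np.array([0.1, 0.2, 0.3, 0.4])

node_ws = trie.weight_max(ws)
assert np.isclose(node_ws[trie.root], ws.max())  # root = max over all tokens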

batch_weight_sum(ws)

Batched equivalent of weight_sum.

Parameters:

    ws (list[torch.Tensor | np.ndarray], required): Batch of token weights, each of shape (len(self.decode),).

Returns:

    np.ndarray: Array of shape (len(ws), num_nodes) with the summed weights for each node in the trie.

Source code in genlm/backend/trie/base.py
def batch_weight_sum(self, ws):
    """Batched equivalent of `weight_sum`.

    Args:
        ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Array of shape `(len(ws), num_nodes)` with the summed weights for each node in the trie
    """
    return np.array([self.weight_sum(w) for w in ws])

batch_weight_max(ws)

Batched equivalent of weight_max.

Parameters:

    ws (list[torch.Tensor | np.ndarray], required): Batch of token weights, each of shape (len(self.decode),).

Returns:

    np.ndarray: Array of shape (len(ws), num_nodes) with the maximum weights for each node in the trie.

Source code in genlm/backend/trie/base.py
def batch_weight_max(self, ws):
    """Batched equivalent of `weight_max`.

    Args:
        ws (list[torch.Tensor|np.ndarray]): Batch of token weights, each of shape `(len(self.decode),)`

    Returns:
        (np.ndarray): Array of shape `(len(ws), num_nodes)` with the maximum weights for each node in the trie
    """
    return np.array([self.weight_max(w) for w in ws])
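
Both batched methods loop over the batch in Python and stack the per-vector results, so the output has one row per weight vector. A sketch with hypothetical sizes:

import numpy as np
from genlm.backend.trie import TokenCharacterTrie

trie = TokenCharacterTrie(decode=[b"a", b"ab", b"b"])
batch = [np.random.rand(3) for _ in range(8)]  # 8 weight vectors over 3 tokens

sums = trie.batch_weight_sum(batch)
maxs = trie.batch_weight_max(batch)
print(sums.shape, maxs.shape)  # both (8, len(trie.children))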

visualize(ws=None)

Visualize the trie structure using Graphviz.

Parameters:

    ws (np.ndarray | None, default None): Optional weight vector to display at each node. Should be of length len(self.children).

Returns:

    graphviz.Digraph: The generated graph object.

Source code in genlm/backend/trie/base.py
def visualize(self, ws=None):
    """Visualize the trie structure using Graphviz.

    Args:
        ws (np.ndarray|None): Optional weight vector to display at each node.
                            Should be of length `len(self.children)`.

    Returns:
        (graphviz.Digraph): The generated graph object
    """
    try:
        import graphviz
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Please install graphviz: pip install graphviz"
        )  # pragma: no cover

    if ws is not None and len(ws) != len(self.children):
        raise ValueError(
            f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
        )

    dot = graphviz.Digraph(comment="Token Character Trie")
    dot.attr(rankdir="LR")

    # Create a subgraph for the legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", fontsize="10")
        legend.attr("node", fontsize="7", width="0.1", height="0.1")

        # Example internal node
        legend.node(
            "legend_internal",
            "Internal Node ID\n'Prefix'\nWeight (if provided)",
            shape="circle",
        )

        # Example leaf node
        legend.node("legend_leaf", "Complete Token", shape="doublecircle")

        legend.edge(
            "legend_internal",
            "legend_leaf",
            label="Token item",
            fontsize="10",
        )

        # Lay out the legend top-to-bottom
        legend.attr(rankdir="TB")
        legend.attr(rank="same")

    # Add the main trie nodes and edges
    for node_id in range(len(self.children)):
        prefix = self.node2prefix[node_id]

        if ws is not None:
            label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
        else:
            label = f"{node_id}\n'{prefix}'"

        # Color nodes based on mass if provided
        if ws is not None:
            max_ws = ws.max()
            if max_ws > 0:
                intensity = int(255 * (1 - ws[node_id] / max_ws))
                color = f"#{intensity:02x}{255:02x}{intensity:02x}"
            else:
                color = "#ffffff"  # white for zero mass
        else:
            color = "#ffffff"  # default white

        if node_id in self.leaf2word:
            dot.node(
                str(node_id),
                label,
                shape="doublecircle",
                style="filled",
                fillcolor=color,
            )
        else:
            dot.node(
                str(node_id), label, shape="circle", style="filled", fillcolor=color
            )

    for node_id, children in enumerate(self.children):
        for char, child_id in children.items():
            if char is not None:
                edge_label = str(char)
            else:
                edge_label = "End-of-Token"

            dot.edge(str(node_id), str(child_id), label=edge_label)

    return dot
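
A rendering sketch: `visualize` returns a standard `graphviz.Digraph`, so writing it to disk uses ordinary graphviz calls (the output file name here is arbitrary).

import numpy as np
from genlm.backend.trie import TokenCharacterTrie

trie = TokenCharacterTrie(decode=[b"a", b"ab", b"b"])
node_ws = trie.weight_sum(np.array([0.2, 0.3, 0.5]))

dot = trie.visualize(ws=node_ws)              # node labels include weights
dot.render("trie", format="svg", view=False)  # writes trie.svg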

AsyncTokenCharacterTrie

An asynchronous wrapper for TokenCharacterTrie implementations that provides automatic request batching.

Source code in genlm/backend/trie/async_impl.py
class AsyncTokenCharacterTrie:
    """An asynchronous wrapper for TokenCharacterTrie implementations that provides automatic request batching."""

    def __init__(self, trie):
        """Initialize an `AsyncTokenCharacterTrie`.

        Args:
            trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
        """
        self.trie = trie
        self._queue = None
        self._task = None

    @classmethod
    def from_vocab(cls, vocab, backend="parallel", **kwargs):
        """Creates an `AsyncTokenCharacterTrie` from a vocabulary.

        Args:
            vocab (list): The vocabulary over which the trie will be defined.
            backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                    Defaults to 'parallel' which uses GPU acceleration when available.
            **kwargs: Additional arguments passed to the trie constructor

        Returns:
            (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.
        """
        if backend == "sequential":
            trie = TokenCharacterTrie(decode=vocab, **kwargs)
        elif backend == "parallel":
            trie = ParallelTokenCharacterTrie(decode=vocab, **kwargs)
        else:
            raise ValueError(
                f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
            )
        return cls(trie)

    async def _queue_request(self, request, op):
        if not self._task or self._task.done():
            self.start()

        future = asyncio.Future()
        await self._queue.put((request, future, op))
        return future

    async def weight_sum(self, ws):
        """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The calculated mass sums for the given distribution.
        """
        future = await self._queue_request(ws, "sum")
        result = await future
        return result

    async def weight_max(self, ws):
        """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The calculated max weights for the given distribution.
        """
        future = await self._queue_request(ws, "max")
        result = await future
        return result

    def start(self):
        """Start the background processing task if not already running."""
        if not self._task or self._task.done():
            self._queue = (
                asyncio.Queue()
            )  # Create a new queue so that it is bound to the current event loop
            self._task = asyncio.create_task(self._background_loop())

    def _do_weight_sums(self, batch_weights):
        return self.trie.batch_weight_sum(batch_weights)

    def _do_weight_maxs(self, batch_weights):
        return self.trie.batch_weight_max(batch_weights)

    async def _background_loop(self):
        """Background task that processes queued weight sum and max requests.

        Continuously monitors the queue for new requests and processes them in batches
        using the underlying trie implementation.

        Raises:
            Exception: If any error occurs during processing, it is propagated to all
                      pending futures in the current batch.
        """
        while True:
            try:
                op_groups = defaultdict(list)

                request, future, op = await self._queue.get()
                op_groups[op].append((request, future))

                while not self._queue.empty():
                    request, future, op = await self._queue.get()
                    op_groups[op].append((request, future))

                for op, group in op_groups.items():
                    requests, futures = zip(*group)

                    if op == "sum":
                        logger.debug(f"processing {len(requests)} sum requests")
                        results = self._do_weight_sums(requests)
                    elif op == "max":
                        logger.debug(f"processing {len(requests)} max requests")
                        results = self._do_weight_maxs(requests)
                    else:
                        raise ValueError(f"Unknown operation: {op}")

                    for future, result in zip(futures, results):
                        future.set_result(result)

            except Exception as e:
                for group in op_groups.values():
                    for _, future in group:
                        if not future.done():
                            future.set_exception(e)
                raise

    async def cleanup(self):
        """Async cleanup - preferred method"""
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    def shutdown(self):
        """Stop the background processing task and cleanup resources."""
        if self._task is not None:
            try:
                self._task.cancel()
            except RuntimeError:
                # Ignore runtime errors that might occur if event loop is closed
                pass
            self._task = None

    def __del__(self):
        self.shutdown()

__init__(trie)

Initialize an AsyncTokenCharacterTrie.

Parameters:

    trie (TokenCharacterTrie | ParallelTokenCharacterTrie, required): The underlying TokenCharacterTrie or ParallelTokenCharacterTrie instance.
Source code in genlm/backend/trie/async_impl.py
def __init__(self, trie):
    """Initialize an `AsyncTokenCharacterTrie`.

    Args:
        trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
    """
    self.trie = trie
    self._queue = None
    self._task = None

from_vocab(vocab, backend='parallel', **kwargs) classmethod

Creates an AsyncTokenCharacterTrie from a vocabulary.

Parameters:

    vocab (list, required): The vocabulary over which the trie will be defined.
    backend (str, default 'parallel'): The trie implementation to use, either 'sequential' or 'parallel'. Defaults to 'parallel', which uses GPU acceleration when available.
    **kwargs: Additional arguments passed to the trie constructor.

Returns:

    AsyncTokenCharacterTrie: The initialized asynchronous trie instance.

Source code in genlm/backend/trie/async_impl.py
@classmethod
def from_vocab(cls, vocab, backend="parallel", **kwargs):
    """Creates an `AsyncTokenCharacterTrie` from a vocabulary.

    Args:
        vocab (list): The vocabulary over which the trie will be defined.
        backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                Defaults to 'parallel' which uses GPU acceleration when available.
        **kwargs: Additional arguments passed to the trie constructor

    Returns:
        (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.
    """
    if backend == "sequential":
        trie = TokenCharacterTrie(decode=vocab, **kwargs)
    elif backend == "parallel":
        trie = ParallelTokenCharacterTrie(decode=vocab, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
        )
    return cls(trie)
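
A usage sketch, assuming the class is importable from `genlm.backend.trie` (inferred from the source path above). Concurrent `weight_sum` calls issued via `asyncio.gather` land in the queue together and are processed as a single batch:

import asyncio
import torch
from genlm.backend.trie import AsyncTokenCharacterTrie

async def main():
    async_trie = AsyncTokenCharacterTrie.from_vocab(
        [b"a", b"ab", b"b"], backend="sequential"
    )
    try:
        # Four concurrent requests, batched into one batch_weight_sum call.
        ws = [torch.rand(3) for _ in range(4)]
        return await asyncio.gather(*(async_trie.weight_sum(w) for w in ws))
    finally:
        await async_trie.cleanup()

results = asyncio.run(main())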

weight_sum(ws) async

Queue a weight_sum request. Multiple concurrent calls will be automatically batched together.

Parameters:

    ws (torch.Tensor, required): Token weights, shape (len(self.trie.decode),).

Returns:

    np.ndarray: The calculated mass sums for the given distribution.

Source code in genlm/backend/trie/async_impl.py
async def weight_sum(self, ws):
    """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated mass sums for the given distribution.
    """
    future = await self._queue_request(ws, "sum")
    result = await future
    return result

weight_max(ws) async

Queue a weight_max request. Multiple concurrent calls will be automatically batched together.

Parameters:

    ws (torch.Tensor, required): Token weights, shape (len(self.trie.decode),).

Returns:

    np.ndarray: The calculated max weights for the given distribution.

Source code in genlm/backend/trie/async_impl.py
async def weight_max(self, ws):
    """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated max weights for the given distribution.
    """
    future = await self._queue_request(ws, "max")
    result = await future
    return result

start()

Start the background processing task if not already running.

Source code in genlm/backend/trie/async_impl.py
def start(self):
    """Start the background processing task if not already running."""
    if not self._task or self._task.done():
        self._queue = (
            asyncio.Queue()
        )  # Create a new queue so that it is bound to the current event loop
        self._task = asyncio.create_task(self._background_loop())

cleanup() async

Async cleanup - preferred method

Source code in genlm/backend/trie/async_impl.py
async def cleanup(self):
    """Async cleanup - preferred method"""
    if self._task and not self._task.done():
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        self._task = None

shutdown()

Stop the background processing task and cleanup resources.

Source code in genlm/backend/trie/async_impl.py
def shutdown(self):
    """Stop the background processing task and cleanup resources."""
    if self._task is not None:
        try:
            self._task.cancel()
        except RuntimeError:
            # Ignore runtime errors that might occur if event loop is closed
            pass
        self._task = None
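
A teardown sketch (hypothetical vocabulary): inside a running event loop, `cleanup` is the preferred path since it awaits the cancelled background task; `shutdown` is the synchronous fallback that `__del__` relies on.

import asyncio
import torch
from genlm.backend.trie import AsyncTokenCharacterTrie

async def main():
    async_trie = AsyncTokenCharacterTrie.from_vocab([b"a", b"b"], backend="sequential")
    try:
        return await async_trie.weight_sum(torch.tensor([1.0, 2.0]))
    finally:
        await async_trie.cleanup()  # cancel and await the background task

asyncio.run(main())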

ParallelTokenCharacterTrie

Bases: TokenCharacterTrie

A GPU-optimized version of TokenCharacterTrie that performs weight sum and max operations in parallel.
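
A construction sketch (hypothetical vocabulary): the parallel trie is a drop-in replacement for the sequential one, with an optional device argument.

import torch
from genlm.backend.trie import ParallelTokenCharacterTrie

# device=None selects "cuda" when available, otherwise "cpu".
trie = ParallelTokenCharacterTrie(decode=[b"a", b"ab", b"b"], device="cpu")

node_ws = trie.weight_sum(torch.tensor([0.1, 0.2, 0.7]))
print(node_ws[trie.root])  # total weight, as in the sequential version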

Source code in genlm/backend/trie/parallel.py
class ParallelTokenCharacterTrie(TokenCharacterTrie):
    """A GPU-optimized version of `TokenCharacterTrie` that performs weight sum and max operations in parallel."""

    def __init__(self, decode, device=None, **kwargs):
        super().__init__(decode, **kwargs)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device not in ["cpu", "cuda"]:
            raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

        self._build_reachability_matrix()
        self.token_ids = torch.tensor(
            self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
        )

    def _build_parent_map(self):
        """Builds a mapping from each node to its parent node in the trie.

        Returns:
            (dict): A dictionary where keys are child nodes and values are their parent nodes.
        """
        parent = {}
        for node in range(len(self.children)):
            for child in self.jump[node]:
                parent[child] = node
        return parent

    def _build_reachability_matrix(self):
        """Constructs a sparse reachability matrix for efficient weight propagation.

        The matrix M is constructed such that M[i,j] = 1 if node j is either:
        - The leaf node i itself (self-connection)
        - An ancestor of leaf node i in the trie
        """
        leaf_indices = self.token_id_to_leaf[:, 1]
        parent = self._build_parent_map()

        rows, cols = [], []
        for i, node in enumerate(leaf_indices):
            # self connections
            rows.append(i)
            cols.append(node)

            current = node
            while current in parent:  # Walk up to root
                ancestor = parent[current]
                rows.append(i)
                cols.append(ancestor)
                current = ancestor

        self.src_indices = torch.tensor(rows, dtype=torch.long, device=self.device)
        self.dst_indices = torch.tensor(cols, dtype=torch.long, device=self.device)

        indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device)
        values = torch.ones(len(rows), device=self.device)

        self.M = torch.sparse_coo_tensor(
            indices, values, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

    def _preprocess_ws(self, batch_ws):
        processed_batch_ws = []
        for ws in batch_ws:
            if not isinstance(ws, torch.Tensor):
                ws = torch.tensor(ws, device=self.device, dtype=torch.float32)
            elif ws.device != self.device or ws.dtype != torch.float32:
                ws = ws.to(device=self.device, dtype=torch.float32)
            assert ws.shape[0] == len(self.decode), [ws.shape[0], len(self.decode)]
            processed_batch_ws.append(ws)
        return torch.stack(processed_batch_ws)

    def weight_sum(self, ws):
        """Computes weight sums given token weights.

        For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
        that are descendants of that node. This is efficiently implemented using sparse matrix multiplication
        with a pre-computed reachability matrix.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (`len(self.children)`,).
        """
        return self.batch_weight_sum(self._preprocess_ws([ws]))[0]

    def batch_weight_sum(self, ws):
        """Batch version of `weight_sum`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)
        masses = torch.sparse.mm(ws[:, self.token_ids], self.M)
        return masses.cpu().numpy()

    def weight_max(self, ws):
        """Computes the max weights given the token weights.

        For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
        that are descendants of that node. This is efficiently implemented using parallel scatter_reduce
        operations on GPU.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (`len(self.children)`,).
        """
        return self.batch_weight_max(self._preprocess_ws([ws]))[0]

    def batch_weight_max(self, ws):
        """Batch version of `weight_max`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)

        # Get leaf weights
        leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
        batch_size = leaf_weights.shape[0]

        # Use scatter_reduce to propagate maximum values in parallel
        result = torch.zeros((batch_size, len(self.children)), device=self.device)
        result.scatter_reduce_(
            dim=1,
            index=self.dst_indices.expand(batch_size, -1),
            src=leaf_weights[:, self.src_indices],
            reduce="amax",
            include_self=False,
        )

        return result.cpu().numpy()

weight_sum(ws)

Computes weight sums given token weights.

For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using sparse matrix multiplication with a pre-computed reachability matrix.

Parameters:

    ws (torch.Tensor, required): Token weights, shape (len(self.decode),).

Returns:

    numpy.ndarray: Summed weights for each node in the trie, shape (len(self.children),).

Source code in genlm/backend/trie/parallel.py
def weight_sum(self, ws):
    """Computes weight sums given token weights.

    For each node in the trie, this computes the sum of weights of all leaf nodes (tokens)
    that are descendants of that node. This is efficiently implemented using sparse matrix multiplication
    with a pre-computed reachability matrix.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (`len(self.children)`,).
    """
    return self.batch_weight_sum(self._preprocess_ws([ws]))[0]

batch_weight_sum(ws)

Batch version of weight_sum.

Parameters:

    ws (torch.Tensor, required): Batch of token weights, shape (batch_size × len(self.decode)).

Returns:

    numpy.ndarray: Summed weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/backend/trie/parallel.py
def batch_weight_sum(self, ws):
    """Batch version of `weight_sum`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)
    masses = torch.sparse.mm(ws[:, self.token_ids], self.M)
    return masses.cpu().numpy()
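
In this batched path the whole computation is one sparse matrix product: the weight matrix is gathered at the leaf columns and multiplied by the precomputed reachability matrix M. A sketch with hypothetical sizes:

import torch
from genlm.backend.trie import ParallelTokenCharacterTrie

trie = ParallelTokenCharacterTrie(decode=[b"a", b"ab", b"b"], device="cpu")

batch = torch.rand(16, 3)            # 16 weight vectors over 3 tokens
sums = trie.batch_weight_sum(batch)  # a single torch.sparse.mm internally
print(sums.shape)                    # (16, len(trie.children))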

weight_max(ws)

Computes the max weights given the token weights.

For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using parallel scatter_reduce operations on GPU.

Parameters:

    ws (torch.Tensor, required): Token weights, shape (len(self.decode),).

Returns:

    numpy.ndarray: Maximum weights for each node in the trie, shape (len(self.children),).

Source code in genlm/backend/trie/parallel.py
def weight_max(self, ws):
    """Computes the max weights given the token weights.

    For each node in the trie, this computes the maximum weight among all leaf nodes (tokens)
    that are descendants of that node. This is efficiently implemented using parallel scatter_reduce
    operations on GPU.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (`len(self.children)`,).
    """
    return self.batch_weight_max(self._preprocess_ws([ws]))[0]

batch_weight_max(ws)

Batch version of weight_max.

Parameters:

    ws (torch.Tensor, required): Batch of token weights, shape (batch_size × len(self.decode)).

Returns:

    numpy.ndarray: Maximum weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/backend/trie/parallel.py
def batch_weight_max(self, ws):
    """Batch version of `weight_max`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)

    # Get leaf weights
    leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
    batch_size = leaf_weights.shape[0]

    # Use scatter_reduce to propagate maximum values in parallel
    result = torch.zeros((batch_size, len(self.children)), device=self.device)
    result.scatter_reduce_(
        dim=1,
        index=self.dst_indices.expand(batch_size, -1),
        src=leaf_weights[:, self.src_indices],
        reduce="amax",
        include_self=False,
    )

    return result.cpu().numpy()
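
For intuition, here is a standalone illustration of the `scatter_reduce_` "amax" pattern used above, independent of the trie: each source entry is routed to a destination index, and every destination keeps the maximum of what it receives.

import torch

src = torch.tensor([[1.0, 5.0, 3.0, 2.0]])  # per-edge leaf weights
dst = torch.tensor([[0, 1, 1, 0]])          # destination node for each entry

out = torch.zeros(1, 2)
out.scatter_reduce_(dim=1, index=dst, src=src, reduce="amax", include_self=False)
print(out)  # tensor([[2., 5.]]): node 0 = max(1, 2), node 1 = max(5, 3)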