Skip to content

trie

genlm.bytes.trie

TrieMode

Bases: Enum

Modes for trie state behavior.

Source code in genlm/bytes/trie.py
class TrieMode(Enum):
    """Modes for trie state behavior."""

    WITHOUT_EOS = "without_eos"  # EOS tokens are treated as normal tokens
    WITH_EOS = "with_eos"  # EOS tokens get special handling (aggregated to EOS node)

TokenByteTrie

A trie data structure for efficient token-to-byte mapping.

Requires Token objects (from genlm.backend.tokenization) which allow handling models with duplicate byte strings (multiple token IDs mapping to the same bytes).

Source code in genlm/bytes/trie.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
class TokenByteTrie:
    """A trie data structure for efficient token-to-byte mapping.

    Requires Token objects (from genlm.backend.tokenization) which allow handling 
    models with duplicate byte strings (multiple token IDs mapping to the same bytes).
    """

    def __init__(
        self,
        decode,
        device=None,
        atomic_byte_strings=None,
        eot_sentinel=None,
        eos_byte_strings=None,
        max_batch_size=64,
    ):
        """Initialize a `TokenByteTrie`.

        Args:
            decode (list[Token]): List of Token objects representing the token vocabulary.
                Each Token must have both token_id and byte_string attributes.
            device (str, optional): Device to use for weight sum and max computations ('cpu' or 'cuda').
            atomic_byte_strings (list[bytes], optional): List of byte strings that should be treated as atomic units rather than being split into individual bytes.
            eot_sentinel (bytes|None, optional): End-of-token sentinel value. Default is None, which represents EOT as None.
            eos_byte_strings (set[bytes], optional): Set of tokens that should be treated as EOS (End of Sequence).
            max_batch_size (int, optional): Maximum batch size for weight sum sparse matrix multiplication.
        """
        if not decode:
            raise ValueError("decode cannot be empty")
        if Token.is_plain_bytes(decode[0]):
            warnings.warn(
                "Passing plain bytes to TokenByteTrie is deprecated. "
                "Use Token objects from decode_vocab() instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            decode = [Token(token_id=i, byte_string=b) for i, b in enumerate(decode)]
        elif not isinstance(decode[0], Token):
            raise TypeError(
                f"decode must contain Token objects, got {type(decode[0]).__name__}. "
                f"Use genlm.backend.tokenization.decode_vocab() to get Token objects."
            )

        self.decode = decode
        self._byte_decode = [t.byte_string for t in decode]
        self.max_batch_size = max_batch_size

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device not in ["cpu", "cuda"]:
            raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

        self.eot_sentinel = eot_sentinel
        self.eos_byte_strings = set(eos_byte_strings or [])
        self.eos_token_ids = [
            token.token_id for token in self.decode 
            if token.byte_string in self.eos_byte_strings
        ]

        self._build_trie(atomic_byte_strings or [])
        self._renumber()
        self._build_node2prefix()
        self._build_reachability_matrix()
        self.token_ids = torch.tensor(
            self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
        )

    def _build_trie(self, atomic_byte_strings):
        """Builds a trie data structure from the vocabulary.

        Handles duplicate byte strings by using (byte_string, token_id) as keys.
        Each token gets its own leaf node, even if multiple tokens share the same bytes.

        Returns:
            (dict): A dictionary where keys are token IDs and values are lists of characters.
        """
        # Check atomic_byte_strings against byte representations
        byte_set = set(self._byte_decode)
        for bs in atomic_byte_strings:
            if bs not in byte_set:
                raise ValueError(f"Atomic byte string {bs!r} not in vocabulary")

        # Check eos_byte_strings against byte representations
        for bs in self.eos_byte_strings:
            if bs not in byte_set:
                raise ValueError(f"EOS byte string {bs!r} not in vocabulary")

        self.word2leaf = {}
        self.children = [{}]  # First node is root
        self.root = 0
        self.token_id_to_leaf = []
        self.lookup = {}

        for token in self.decode:
            token_id = token.token_id
            word = token.byte_string

            # Use (word, token_id) as lookup key to allow duplicates
            lookup_key = (word, token_id)
            if lookup_key in self.lookup:  # pragma: no cover
                # This should never happen since Token objects have unique token_ids
                raise ValueError(f"Duplicate token in vocabulary: {token}, lookup_key: {lookup_key}")
            self.lookup[lookup_key] = token_id

            # Build ALL tokens in trie (including EOS tokens for conditioning mode)
            curr = self.root
            letters = [word] if word in atomic_byte_strings else word
            for letter in letters:
                if letter not in self.children[curr]:
                    self.children[curr][letter] = len(self.children)
                    self.children.append({})
                curr = self.children[curr][letter]

            # Each token gets its own leaf, using (eot_sentinel, token_id) as edge key
            # This allows multiple tokens with the same byte_string to have separate leaves
            leaf_edge_key = (self.eot_sentinel, token_id)
            self.children[curr][leaf_edge_key] = last = len(self.children)
            self.children.append({})

            # Use (word, token_id) as key in word2leaf to handle duplicates
            self.word2leaf[(word, token_id)] = last
            self.token_id_to_leaf.append((token_id, last))

        self.eos_node = len(self.children)
        self.children.append({})  # Create the EOS node
        self.children[self.root][EOS] = self.eos_node

        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
        self.jump = [
            np.array(sorted(x.values()), dtype=np.int32) for x in self.children
        ]

    def _renumber(self):
        """Renumber the states of the trie so that they are named by a contiguous
        range of integers and those integers respect the topological ordering
        of the trie. This improves the efficiency of the updating the trie as
        it improves memory locality.
        """
        self.ordering = np.array(list(self._order(self.root)), np.int32)
        ordering = {}
        for i, x in enumerate(self._order_full(self.root)):
            ordering[x] = i
        self._rename(f=lambda x: ordering[x])

    def _order(self, node):
        """Generate a topological ordering of nodes beneath the given node.

        Args:
            node (int): Starting node index

        Yields:
            int: Node indices in topological order
        """
        for a in self.children[node]:
            # Skip leaf edges (tuples like (eot_sentinel, token_id)) from ordering
            # but include all other edges including EOS (257)
            if isinstance(a, tuple):
                pass  # Skip leaf edges in ordering
            else:
                yield from self._order(self.children[node][a])
        yield node

    def _order_full(self, node):
        """Generate a complete topological ordering including all child nodes.

        Args:
            node (int): Starting node index

        Yields:
            (int): Node indices in complete topological order
        """
        for a in self.children[node]:
            yield from self._order_full(self.children[node][a])
        yield node

    def _rename(self, f):
        """Rename all node indices in the trie using the provided mapping function.

        Args:
            f (callable): Function that maps old node indices to new node indices
        """
        N = len(self.children)

        new_children = [{} for _ in range(N)]
        nodes = range(N)

        for x in nodes:
            for letter, y in self.children[x].items():
                new_children[f(x)][letter] = f(y)

        self.root = f(self.root)
        self.children = new_children
        self.word2leaf = {w: f(x) for w, x in self.word2leaf.items()}
        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))

        self.token_id_to_leaf = np.array(
            [(i, f(x)) for i, x in self.token_id_to_leaf], dtype=np.int32
        )
        self.leaf2token_id = dict(
            zip(self.token_id_to_leaf[:, 1], self.token_id_to_leaf[:, 0])
        )

        self.ordering = np.array([f(x) for x in self.ordering])
        self.jump = [np.array(sorted(x.values()), dtype=np.int32) for x in new_children]

        # Update EOS node after renumbering
        self.eos_node = f(self.eos_node)

    def _build_node2prefix(self):
        """Builds a mapping from each node to its prefix.

        Returns:
            (dict): A dictionary where keys are node IDs and values are lists of characters.
        """
        node2prefix = {self.root: []}
        for x in reversed(range(len(self.children))):
            for letter, y in self.children[x].items():
                # Handle leaf edges: (eot_sentinel, token_id) tuples
                if isinstance(letter, tuple):
                    # This is a leaf edge, prefix stays the same
                    node2prefix[y] = node2prefix[x]
                elif isinstance(letter, bytes):
                    node2prefix[y] = node2prefix[x] + list(letter)
                else:
                    # Regular byte transition (int)
                    node2prefix[y] = node2prefix[x] + [letter]

        self.node2prefix = node2prefix

    def _build_parent_map(self):
        """Builds a mapping from each node to its parent node in the trie.

        Returns:
            (dict): A dictionary where keys are child nodes and values are their parent nodes.
        """
        parent = {}
        for node in range(len(self.children)):
            for child in self.jump[node]:
                parent[child] = node
        return parent

    def _build_reachability_matrix(self):
        """Constructs dual sparse reachability matrices for efficient weight propagation.

        The matrix M is constructed such that M[i,j] = 1 if node j is either:
        - The leaf node i itself (self-connection)
        - An ancestor of leaf node i in the trie

        For propagate_eos mode, EOS tokens contribute directly to eos_node and root.
        """
        leaf_indices = self.token_id_to_leaf[:, 1]
        parent = self._build_parent_map()
        # Build no_eos matrix (includes all tokens, doesn't map any tokens to the eos_node)
        rows_no_eos, cols_no_eos = [], []
        # Build with_eos matrix (maps EOS tokens to the eos_node only)
        rows_with_eos, cols_with_eos = [], []

        for i, node in enumerate(leaf_indices):
            token_id = self.token_id_to_leaf[i, 0]
            token = self.decode[token_id]
            token_bytes = token.byte_string

            # self-connection
            rows_no_eos.append(i)
            cols_no_eos.append(node)
            if token_bytes not in self.eos_byte_strings:
                rows_with_eos.append(i)
                cols_with_eos.append(node)
            else:
                # EOS tokens: contribute directly to eos_node and root
                rows_with_eos.append(i)
                cols_with_eos.append(self.eos_node)
                rows_with_eos.append(i)
                cols_with_eos.append(self.root)

            current = node
            while current in parent:
                ancestor = parent[current]
                rows_no_eos.append(i)
                cols_no_eos.append(ancestor)
                if token_bytes not in self.eos_byte_strings:
                    rows_with_eos.append(i)
                    cols_with_eos.append(ancestor)
                current = ancestor

        # Build without_eos matrix
        indices_no_eos = torch.tensor(
            [rows_no_eos, cols_no_eos], dtype=torch.long, device=self.device
        )
        values_no_eos = torch.ones(len(rows_no_eos), device=self.device)
        self.M_no_eos = torch.sparse_coo_tensor(
            indices_no_eos, values_no_eos, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

        # Build with_eos matrix
        indices_with_eos = torch.tensor(
            [rows_with_eos, cols_with_eos], dtype=torch.long, device=self.device
        )
        values_with_eos = torch.ones(len(rows_with_eos), device=self.device)
        self.M_with_eos = torch.sparse_coo_tensor(
            indices_with_eos, values_with_eos, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

        # Keep the old matrix for backward compatibility
        self.M = self.M_no_eos
        self.src_indices = torch.tensor(
            rows_no_eos, dtype=torch.long, device=self.device
        )
        self.dst_indices = torch.tensor(
            cols_no_eos, dtype=torch.long, device=self.device
        )

    def _preprocess_ws(self, batch_ws):
        """Preprocess weight sums for batch processing.

        Args:
            batch_ws (list|np.ndarray|torch.Tensor): List of weight sum tensors or lists of weight sums.

        Returns:
            (torch.Tensor): Stacked weight sum tensor.
        """
        processed_batch_ws = []
        for ws in batch_ws:
            if not isinstance(ws, torch.Tensor):
                ws = torch.tensor(ws, device=self.device, dtype=torch.float32)
            elif ws.device != self.device or ws.dtype != torch.float32:
                ws = ws.to(device=self.device, dtype=torch.float32)
            assert ws.shape[0] == len(self.decode), [ws.shape[0], len(self.decode)]
            processed_batch_ws.append(ws)
        return torch.stack(processed_batch_ws)

    def weight_sum(self, ws, mode=None):
        """Computes the sum of weights of all leaf nodes (tokens) that are descendants of each node in the trie.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).
            mode (TrieMode, optional): Trie mode - determines matrix selection.
                                     If None, defaults to WITHOUT_EOS.

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (num_nodes,).
        """
        mode = mode or TrieMode.WITHOUT_EOS
        return self.batch_weight_sum(self._preprocess_ws([ws]), mode=mode)[0]

    def batch_weight_sum(self, ws, mode=None):
        """Batch version of `weight_sum`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).
            mode (TrieMode, optional): Trie mode - determines matrix selection.
                                     If None, defaults to WITHOUT_EOS.

        Returns:
            (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
        """
        mode = mode or TrieMode.WITHOUT_EOS

        ws = self._preprocess_ws(ws)
        batch_size = ws.shape[0]
        all_masses = []

        # Choose matrix based on mode
        matrix = self.M_with_eos if mode == TrieMode.WITH_EOS else self.M_no_eos

        # If you are getting illegal memory access errors here,
        # try reducing the max_batch_size.
        for i in range(0, batch_size, self.max_batch_size):
            batch_ws = ws[i : i + self.max_batch_size]
            masses = torch.sparse.mm(batch_ws[:, self.token_ids], matrix)
            all_masses.append(masses)
        return torch.cat(all_masses, dim=0)

    def weight_max(self, ws):
        """Computes the maximum weight of all descendant leaf nodes (tokens) for each node in the trie.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (num_nodes,).
        """
        return self.batch_weight_max(self._preprocess_ws([ws]))[0]

    def batch_weight_max(self, ws):
        """Batch version of `weight_max`.

        Args:
            ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

        Returns:
            (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
        """
        ws = self._preprocess_ws(ws)

        # Get leaf weights
        leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
        batch_size = leaf_weights.shape[0]

        # Use scatter_reduce to propagate maximum values in parallel
        result = torch.zeros((batch_size, len(self.children)), device=self.device)
        result.scatter_reduce_(
            dim=1,
            index=self.dst_indices.expand(batch_size, -1),
            src=leaf_weights[:, self.src_indices],
            reduce="amax",
            include_self=False,
        )

        return result

    def visualize(self, ws=None):
        """Visualize the trie structure using Graphviz.

        Args:
            ws (np.ndarray|None): Optional weight vector to display at each node. Should be of length `len(self.children)`.

        Returns:
            (graphviz.Digraph): The generated graph object
        """
        try:
            import graphviz
        except ImportError:  # pragma: no cover
            raise ImportError(
                "Please install graphviz: pip install graphviz"
            )  # pragma: no cover

        if ws is not None and len(ws) != len(self.children):
            raise ValueError(
                f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
            )

        dot = graphviz.Digraph(comment="Token Character Trie")
        dot.attr(rankdir="LR")

        # Create a subgraph for the legend
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", fontsize="10")
            legend.attr("node", fontsize="7", width="0.1", height="0.1")

            # Example internal node
            legend.node(
                "legend_internal",
                "Internal Node ID\n'Prefix'\nWeight (if provided)",
                shape="circle",
            )

            # Example leaf node
            legend.node("legend_leaf", "Complete Token", shape="doublecircle")

            legend.edge(
                "legend_internal",
                "legend_leaf",
                label="Token item",
                fontsize="10",
            )

            # Align legend horizontally
            legend.attr(rankdir="TB")
            legend.attr(rank="same")

        # Add the main trie nodes and edges
        for node_id in range(len(self.children)):
            prefix = self.node2prefix[node_id]

            if ws is not None:
                label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
            else:
                label = f"{node_id}\n'{prefix}'"

            # Color nodes based on mass if provided
            if ws is not None:
                max_ws = ws.max()
                if max_ws > 0:
                    intensity = int(255 * (1 - ws[node_id] / max_ws))
                    color = f"#{intensity:02x}{255:02x}{intensity:02x}"
                else:
                    color = "#ffffff"  # white for zero mass
            else:
                color = "#ffffff"  # default white

            if node_id in self.leaf2word:
                dot.node(
                    str(node_id),
                    label,
                    shape="doublecircle",
                    style="filled",
                    fillcolor=color,
                )
            else:
                dot.node(
                    str(node_id), label, shape="circle", style="filled", fillcolor=color
                )

        for node_id, children in enumerate(self.children):
            for char, child_id in children.items():
                # Handle leaf edges: (eot_sentinel, token_id) tuples
                if isinstance(char, tuple):
                    _, token_id = char
                    edge_label = f"EOT (ID: {token_id})"
                else:
                    # Regular byte transition (int) or EOS
                    edge_label = str(char)

                dot.edge(str(node_id), str(child_id), label=edge_label)

        return dot

__init__(decode, device=None, atomic_byte_strings=None, eot_sentinel=None, eos_byte_strings=None, max_batch_size=64)

Initialize a TokenByteTrie.

Parameters:

Name Type Description Default
decode list[Token]

List of Token objects representing the token vocabulary. Each Token must have both token_id and byte_string attributes.

required
device str

Device to use for weight sum and max computations ('cpu' or 'cuda').

None
atomic_byte_strings list[bytes]

List of byte strings that should be treated as atomic units rather than being split into individual bytes.

None
eot_sentinel bytes | None

End-of-token sentinel value. Default is None, which represents EOT as None.

None
eos_byte_strings set[bytes]

Set of tokens that should be treated as EOS (End of Sequence).

None
max_batch_size int

Maximum batch size for weight sum sparse matrix multiplication.

64
Source code in genlm/bytes/trie.py
def __init__(
    self,
    decode,
    device=None,
    atomic_byte_strings=None,
    eot_sentinel=None,
    eos_byte_strings=None,
    max_batch_size=64,
):
    """Initialize a `TokenByteTrie`.

    Args:
        decode (list[Token]): List of Token objects representing the token vocabulary.
            Each Token must have both token_id and byte_string attributes.
        device (str, optional): Device to use for weight sum and max computations ('cpu' or 'cuda').
        atomic_byte_strings (list[bytes], optional): List of byte strings that should be treated as atomic units rather than being split into individual bytes.
        eot_sentinel (bytes|None, optional): End-of-token sentinel value. Default is None, which represents EOT as None.
        eos_byte_strings (set[bytes], optional): Set of tokens that should be treated as EOS (End of Sequence).
        max_batch_size (int, optional): Maximum batch size for weight sum sparse matrix multiplication.
    """
    if not decode:
        raise ValueError("decode cannot be empty")
    if Token.is_plain_bytes(decode[0]):
        warnings.warn(
            "Passing plain bytes to TokenByteTrie is deprecated. "
            "Use Token objects from decode_vocab() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        decode = [Token(token_id=i, byte_string=b) for i, b in enumerate(decode)]
    elif not isinstance(decode[0], Token):
        raise TypeError(
            f"decode must contain Token objects, got {type(decode[0]).__name__}. "
            f"Use genlm.backend.tokenization.decode_vocab() to get Token objects."
        )

    self.decode = decode
    self._byte_decode = [t.byte_string for t in decode]
    self.max_batch_size = max_batch_size

    self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if self.device not in ["cpu", "cuda"]:
        raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

    self.eot_sentinel = eot_sentinel
    self.eos_byte_strings = set(eos_byte_strings or [])
    self.eos_token_ids = [
        token.token_id for token in self.decode 
        if token.byte_string in self.eos_byte_strings
    ]

    self._build_trie(atomic_byte_strings or [])
    self._renumber()
    self._build_node2prefix()
    self._build_reachability_matrix()
    self.token_ids = torch.tensor(
        self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
    )

weight_sum(ws, mode=None)

Computes the sum of weights of all leaf nodes (tokens) that are descendants of each node in the trie.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.decode),).

required
mode TrieMode

Trie mode - determines matrix selection. If None, defaults to WITHOUT_EOS.

None

Returns:

Type Description
ndarray

Summed weights for each node in the trie, shape (num_nodes,).

Source code in genlm/bytes/trie.py
def weight_sum(self, ws, mode=None):
    """Computes the sum of weights of all leaf nodes (tokens) that are descendants of each node in the trie.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).
        mode (TrieMode, optional): Trie mode - determines matrix selection.
                                 If None, defaults to WITHOUT_EOS.

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (num_nodes,).
    """
    mode = mode or TrieMode.WITHOUT_EOS
    return self.batch_weight_sum(self._preprocess_ws([ws]), mode=mode)[0]

batch_weight_sum(ws, mode=None)

Batch version of weight_sum.

Parameters:

Name Type Description Default
ws Tensor

Batch of token weights, shape (batch_size × len(self.decode)).

required
mode TrieMode

Trie mode - determines matrix selection. If None, defaults to WITHOUT_EOS.

None

Returns:

Type Description
ndarray

Summed weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/bytes/trie.py
def batch_weight_sum(self, ws, mode=None):
    """Batch version of `weight_sum`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).
        mode (TrieMode, optional): Trie mode - determines matrix selection.
                                 If None, defaults to WITHOUT_EOS.

    Returns:
        (numpy.ndarray): Summed weights for each node in the trie, shape (batch_size × num_nodes).
    """
    mode = mode or TrieMode.WITHOUT_EOS

    ws = self._preprocess_ws(ws)
    batch_size = ws.shape[0]
    all_masses = []

    # Choose matrix based on mode
    matrix = self.M_with_eos if mode == TrieMode.WITH_EOS else self.M_no_eos

    # If you are getting illegal memory access errors here,
    # try reducing the max_batch_size.
    for i in range(0, batch_size, self.max_batch_size):
        batch_ws = ws[i : i + self.max_batch_size]
        masses = torch.sparse.mm(batch_ws[:, self.token_ids], matrix)
        all_masses.append(masses)
    return torch.cat(all_masses, dim=0)

weight_max(ws)

Computes the maximum weight of all descendant leaf nodes (tokens) for each node in the trie.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.decode),).

required

Returns:

Type Description
ndarray

Maximum weights for each node in the trie, shape (num_nodes,).

Source code in genlm/bytes/trie.py
def weight_max(self, ws):
    """Computes the maximum weight of all descendant leaf nodes (tokens) for each node in the trie.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.decode)`,).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (num_nodes,).
    """
    return self.batch_weight_max(self._preprocess_ws([ws]))[0]

batch_weight_max(ws)

Batch version of weight_max.

Parameters:

Name Type Description Default
ws Tensor

Batch of token weights, shape (batch_size × len(self.decode)).

required

Returns:

Type Description
ndarray

Maximum weights for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm/bytes/trie.py
def batch_weight_max(self, ws):
    """Batch version of `weight_max`.

    Args:
        ws (torch.Tensor): Batch of token weights, shape (batch_size × `len(self.decode)`).

    Returns:
        (numpy.ndarray): Maximum weights for each node in the trie, shape (batch_size × num_nodes).
    """
    ws = self._preprocess_ws(ws)

    # Get leaf weights
    leaf_weights = ws[:, self.token_ids]  # shape: (batch_size × num_leafs)
    batch_size = leaf_weights.shape[0]

    # Use scatter_reduce to propagate maximum values in parallel
    result = torch.zeros((batch_size, len(self.children)), device=self.device)
    result.scatter_reduce_(
        dim=1,
        index=self.dst_indices.expand(batch_size, -1),
        src=leaf_weights[:, self.src_indices],
        reduce="amax",
        include_self=False,
    )

    return result

visualize(ws=None)

Visualize the trie structure using Graphviz.

Parameters:

Name Type Description Default
ws ndarray | None

Optional weight vector to display at each node. Should be of length len(self.children).

None

Returns:

Type Description
Digraph

The generated graph object

Source code in genlm/bytes/trie.py
def visualize(self, ws=None):
    """Visualize the trie structure using Graphviz.

    Args:
        ws (np.ndarray|None): Optional weight vector to display at each node. Should be of length `len(self.children)`.

    Returns:
        (graphviz.Digraph): The generated graph object
    """
    try:
        import graphviz
    except ImportError:  # pragma: no cover
        raise ImportError(
            "Please install graphviz: pip install graphviz"
        )  # pragma: no cover

    if ws is not None and len(ws) != len(self.children):
        raise ValueError(
            f"Weight vector length ({len(ws)}) must match number of nodes ({len(self.children)})"
        )

    dot = graphviz.Digraph(comment="Token Character Trie")
    dot.attr(rankdir="LR")

    # Create a subgraph for the legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", fontsize="10")
        legend.attr("node", fontsize="7", width="0.1", height="0.1")

        # Example internal node
        legend.node(
            "legend_internal",
            "Internal Node ID\n'Prefix'\nWeight (if provided)",
            shape="circle",
        )

        # Example leaf node
        legend.node("legend_leaf", "Complete Token", shape="doublecircle")

        legend.edge(
            "legend_internal",
            "legend_leaf",
            label="Token item",
            fontsize="10",
        )

        # Align legend horizontally
        legend.attr(rankdir="TB")
        legend.attr(rank="same")

    # Add the main trie nodes and edges
    for node_id in range(len(self.children)):
        prefix = self.node2prefix[node_id]

        if ws is not None:
            label = f"{node_id}\n'{prefix}'\n{ws[node_id]:.4f}"
        else:
            label = f"{node_id}\n'{prefix}'"

        # Color nodes based on mass if provided
        if ws is not None:
            max_ws = ws.max()
            if max_ws > 0:
                intensity = int(255 * (1 - ws[node_id] / max_ws))
                color = f"#{intensity:02x}{255:02x}{intensity:02x}"
            else:
                color = "#ffffff"  # white for zero mass
        else:
            color = "#ffffff"  # default white

        if node_id in self.leaf2word:
            dot.node(
                str(node_id),
                label,
                shape="doublecircle",
                style="filled",
                fillcolor=color,
            )
        else:
            dot.node(
                str(node_id), label, shape="circle", style="filled", fillcolor=color
            )

    for node_id, children in enumerate(self.children):
        for char, child_id in children.items():
            # Handle leaf edges: (eot_sentinel, token_id) tuples
            if isinstance(char, tuple):
                _, token_id = char
                edge_label = f"EOT (ID: {token_id})"
            else:
                # Regular byte transition (int) or EOS
                edge_label = str(char)

            dot.edge(str(node_id), str(child_id), label=edge_label)

    return dot

TrieOp

Bases: Enum

Enumeration of supported trie operations.

Source code in genlm/bytes/trie.py
class TrieOp(Enum):
    """Enumeration of supported trie operations."""

    SUM = "sum"
    MAX = "max"

AsyncTokenByteTrie

An asynchronous wrapper for TokenByteTrie implementations that provides automatic request batching.

Source code in genlm/bytes/trie.py
class AsyncTokenByteTrie:
    """An asynchronous wrapper for TokenByteTrie implementations that provides automatic request batching."""

    def __init__(self, trie):
        """Initialize an `AsyncTokenByteTrie`.

        Args:
            trie (TokenByteTrie): The underlying `TokenByteTrie` instance
        """
        self.trie = trie
        self._queue = None
        self._task = None

    @classmethod
    def from_vocab(cls, vocab, **kwargs):
        """Creates an `AsyncTokenByteTrie` from a vocabulary.

        Args:
            vocab (list[Token]): List of Token objects representing the vocabulary.
                Use genlm.backend.tokenization.decode_vocab() to get Token objects from a tokenizer.
            **kwargs (dict): Additional arguments passed to the trie constructor.
                             Can include 'eos_byte_strings' for EOS support.

        Returns:
            (AsyncTokenByteTrie): The initialized asynchronous trie instance.
        """
        trie = TokenByteTrie(decode=vocab, **kwargs)
        return cls(trie)

    def _queue_request(self, ws, mode, op):
        if not self._task or self._task.done():
            self.start()

        future = asyncio.get_running_loop().create_future()
        self._queue.put_nowait(((ws, mode), future, op))
        return future

    async def weight_sum(self, ws, mode=None):
        """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
        together by (operation, mode) pairs.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).
            mode (TrieMode, optional): Trie mode determining EOS handling. Defaults to WITHOUT_EOS.

        Returns:
            (np.ndarray): The calculated mass sums for the given distribution.
        """
        mode = mode or TrieMode.WITHOUT_EOS
        return await self._queue_request(ws, mode, TrieOp.SUM)

    async def weight_max(self, ws):
        """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
        together.

        Args:
            ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

        Returns:
            (np.ndarray): The calculated max weights for the given distribution.
        """
        # For MAX, mode doesn't matter so use WITHOUT_EOS as default
        return await self._queue_request(ws, TrieMode.WITHOUT_EOS, TrieOp.MAX)

    def start(self):
        """Start the background processing task if not already running."""
        if not self._task or self._task.done():
            logger.debug("starting background loop")
            # Create a new queue so that it is bound to the current event loop
            self._queue = asyncio.Queue()
            self._task = asyncio.create_task(self._background_loop())

    async def _background_loop(self):
        """Background task that processes queued weight sum and max requests.

        Continuously monitors the queue for new requests and processes them in batches
        grouped by (operation, mode) pairs using the underlying trie implementation.

        Raises:
            (Exception): If any error occurs during processing, it is propagated to all
                         pending futures in the current batch.
        """
        while True:
            try:
                # Group by (operation, mode) pairs for efficient batching
                op_mode_groups = defaultdict(list)

                (ws, mode), future, op = await self._queue.get()
                op_mode_groups[(op, mode)].append(((ws, mode), future))

                try:
                    while True:
                        (ws, mode), future, op = self._queue.get_nowait()
                        op_mode_groups[(op, mode)].append(((ws, mode), future))
                except asyncio.QueueEmpty:
                    pass

                for (op, mode), group in op_mode_groups.items():
                    requests, futures = zip(*group)
                    # Extract just the ws tensors from the (ws, mode) tuples
                    ws_list = [req[0] for req in requests]

                    if op == TrieOp.SUM:
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(
                                f"processing {len(ws_list)} sum requests with mode {mode}"
                            )  # pragma: no cover
                        results = self.trie.batch_weight_sum(ws_list, mode=mode)
                    elif op == TrieOp.MAX:
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(
                                f"processing {len(ws_list)} max requests"
                            )  # pragma: no cover
                        # MAX operations don't need mode, so use the original batch_weight_max
                        results = self.trie.batch_weight_max(ws_list)
                    else:  # pragma: no cover
                        raise ValueError(f"Unknown trie operation: {op}")

                    for future, result in zip(futures, results):
                        future.set_result(result)

            except Exception as e:
                for group in op_mode_groups.values():
                    for _, future in group:
                        if not future.done():
                            future.set_exception(e)
                raise

    async def cleanup(self):
        """Async cleanup - preferred method"""
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    def shutdown(self):
        """Stop the background processing task and cleanup resources."""
        if self._task is not None:
            try:
                self._task.cancel()
            except RuntimeError:  # pragma: no cover
                # Ignore runtime errors that might occur if event loop is closed
                pass
            self._task = None

    def __del__(self):
        self.shutdown()

__init__(trie)

Initialize an AsyncTokenByteTrie.

Parameters:

Name Type Description Default
trie TokenByteTrie

The underlying TokenByteTrie instance

required
Source code in genlm/bytes/trie.py
def __init__(self, trie):
    """Initialize an `AsyncTokenByteTrie`.

    Args:
        trie (TokenByteTrie): The underlying `TokenByteTrie` instance
    """
    self.trie = trie
    self._queue = None
    self._task = None

from_vocab(vocab, **kwargs) classmethod

Creates an AsyncTokenByteTrie from a vocabulary.

Parameters:

Name Type Description Default
vocab list[Token]

List of Token objects representing the vocabulary. Use genlm.backend.tokenization.decode_vocab() to get Token objects from a tokenizer.

required
**kwargs dict

Additional arguments passed to the trie constructor. Can include 'eos_byte_strings' for EOS support.

{}

Returns:

Type Description
AsyncTokenByteTrie

The initialized asynchronous trie instance.

Source code in genlm/bytes/trie.py
@classmethod
def from_vocab(cls, vocab, **kwargs):
    """Creates an `AsyncTokenByteTrie` from a vocabulary.

    Args:
        vocab (list[Token]): List of Token objects representing the vocabulary.
            Use genlm.backend.tokenization.decode_vocab() to get Token objects from a tokenizer.
        **kwargs (dict): Additional arguments passed to the trie constructor.
                         Can include 'eos_byte_strings' for EOS support.

    Returns:
        (AsyncTokenByteTrie): The initialized asynchronous trie instance.
    """
    trie = TokenByteTrie(decode=vocab, **kwargs)
    return cls(trie)

weight_sum(ws, mode=None) async

Queue a weight_sum request. Multiple concurrent calls will be automatically batched together by (operation, mode) pairs.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.trie.decode),).

required
mode TrieMode

Trie mode determining EOS handling. Defaults to WITHOUT_EOS.

None

Returns:

Type Description
ndarray

The calculated mass sums for the given distribution.

Source code in genlm/bytes/trie.py
async def weight_sum(self, ws, mode=None):
    """Queue a `weight_sum` request. Multiple concurrent calls will be automatically batched
    together by (operation, mode) pairs.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).
        mode (TrieMode, optional): Trie mode determining EOS handling. Defaults to WITHOUT_EOS.

    Returns:
        (np.ndarray): The calculated mass sums for the given distribution.
    """
    mode = mode or TrieMode.WITHOUT_EOS
    return await self._queue_request(ws, mode, TrieOp.SUM)

weight_max(ws) async

Queue a weight_max request. Multiple concurrent calls will be automatically batched together.

Parameters:

Name Type Description Default
ws Tensor

Token weights, shape (len(self.trie.decode),).

required

Returns:

Type Description
ndarray

The calculated max weights for the given distribution.

Source code in genlm/bytes/trie.py
async def weight_max(self, ws):
    """Queue a `weight_max` request. Multiple concurrent calls will be automatically batched
    together.

    Args:
        ws (torch.Tensor): Token weights, shape (`len(self.trie.decode)`,).

    Returns:
        (np.ndarray): The calculated max weights for the given distribution.
    """
    # For MAX, mode doesn't matter so use WITHOUT_EOS as default
    return await self._queue_request(ws, TrieMode.WITHOUT_EOS, TrieOp.MAX)

start()

Start the background processing task if not already running.

Source code in genlm/bytes/trie.py
def start(self):
    """Start the background processing task if not already running."""
    if not self._task or self._task.done():
        logger.debug("starting background loop")
        # Create a new queue so that it is bound to the current event loop
        self._queue = asyncio.Queue()
        self._task = asyncio.create_task(self._background_loop())

cleanup() async

Async cleanup - preferred method

Source code in genlm/bytes/trie.py
async def cleanup(self):
    """Async cleanup - preferred method"""
    if self._task and not self._task.done():
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        self._task = None

shutdown()

Stop the background processing task and cleanup resources.

Source code in genlm/bytes/trie.py
def shutdown(self):
    """Stop the background processing task and cleanup resources."""
    if self._task is not None:
        try:
            self._task.cancel()
        except RuntimeError:  # pragma: no cover
            # Ignore runtime errors that might occur if event loop is closed
            pass
        self._task = None