parallel
ParallelTokenCharacterTrie
Bases: TokenCharacterTrie
A GPU-optimized version of TokenCharacterTrie
that performs weight sum and max operations in parallel.
Source code in genlm/backend/trie/parallel.py
weight_sum(ws)
Computes weight sums given token weights.
For each node in the trie, this computes the sum of weights of all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using sparse matrix multiplication with a pre-computed reachability matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ws | Tensor | Token weights, shape ( | required |
Returns:
Type | Description |
---|---|
ndarray | Summed weights for each node in the trie, shape ( |
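
The reachability idea described above can be sketched on a toy trie. This is a minimal illustration, not the library's implementation: the `parent` and `leaf_of_token` arrays and the `build_reachability` helper are hypothetical, and a dense NumPy matrix stands in for the sparse one.

```python
import numpy as np

# Hypothetical toy trie over the tokens "a", "ab", "b" (not genlm's layout).
# parent[i] is the parent of node i; the root (node 0) has parent -1.
parent = [-1, 0, 1, 1, 3, 0, 5]
# leaf_of_token[t] is the leaf node that ends token t.
leaf_of_token = [2, 4, 6]


def build_reachability(parent, leaf_of_token):
    """Return M with M[n, t] = 1 if token t's leaf lies at or below node n."""
    M = np.zeros((len(parent), len(leaf_of_token)))
    for t, node in enumerate(leaf_of_token):
        while node != -1:  # walk from the leaf up to the root
            M[node, t] = 1.0
            node = parent[node]
    return M


M = build_reachability(parent, leaf_of_token)  # built once, reused per call
ws = np.array([0.5, 0.25, 0.25])               # one weight per token
node_sums = M @ ws                             # one matrix-vector product
print(node_sums)
```

Because most entries of such a matrix are zero (each token is reachable only from the nodes on its own root-to-leaf path), storing it in a sparse format keeps the product cheap for large vocabularies.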
batch_weight_sum(ws)
Batch version of weight_sum.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ws | Tensor | Batch of token weights, shape (batch_size × | required |
Returns:
Type | Description |
---|---|
ndarray | Summed weights for each node in the trie, shape (batch_size × num_nodes). |
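
One way to picture the batched form is as a single matrix-matrix product instead of one matrix-vector product per row. The sketch below assumes a hand-written 0/1 reachability matrix `M` for a hypothetical 7-node, 3-token trie; it is dense NumPy for readability and is not taken from the library.

```python
import numpy as np

# Hypothetical reachability matrix: M[n, t] = 1 when token t is a
# descendant of node n in a toy trie over the tokens "a", "ab", "b".
M = np.array(
    [
        [1, 1, 1],  # root
        [1, 1, 0],  # prefix "a"
        [1, 0, 0],  # leaf for token "a"
        [0, 1, 0],  # prefix "ab"
        [0, 1, 0],  # leaf for token "ab"
        [0, 0, 1],  # prefix "b"
        [0, 0, 1],  # leaf for token "b"
    ],
    dtype=float,
)

# Two weight vectors stacked into shape (batch_size, num_tokens).
ws_batch = np.array(
    [
        [0.5, 0.25, 0.25],
        [0.0, 1.0, 0.0],
    ]
)

# One product yields shape (batch_size, num_nodes).
batch_sums = ws_batch @ M.T

# Each row agrees with the single-vector product for that row.
per_row = np.stack([M @ ws for ws in ws_batch])
print(batch_sums)
```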
weight_max(ws)
Computes the max weights given the token weights.
For each node in the trie, this computes the maximum weight among all leaf nodes (tokens) that are descendants of that node. This is efficiently implemented using parallel scatter_reduce operations on GPU.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ws | Tensor | Token weights, shape ( | required |
Returns:
Type | Description |
---|---|
ndarray | Maximum weights for each node in the trie, shape ( |
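
The scatter-style reduction described above can be illustrated on the CPU with NumPy's `np.maximum.at`, used here as a stand-in for a GPU scatter-reduce with a max reduction. The `node_ids` and `token_ids` pairs are hypothetical and describe a toy trie over the tokens "a", "ab", "b"; this is a sketch of the idea, not the library's code.

```python
import numpy as np

# Hypothetical (node, token) pairs: token_ids[k] is a descendant of
# node_ids[k]. Node 0 is the root and reaches all three tokens.
node_ids = np.array([0, 0, 0, 1, 1, 2, 3, 4, 5, 6])
token_ids = np.array([0, 1, 2, 0, 1, 0, 1, 1, 2, 2])
num_nodes = 7

ws = np.array([0.2, 0.7, 0.1])  # one weight per token

# Start from -inf, the identity element for max.
node_max = np.full(num_nodes, -np.inf)

# Unbuffered scatter: every pair updates its node with max(current, weight),
# so repeated node indices are all taken into account.
np.maximum.at(node_max, node_ids, ws[token_ids])
print(node_max)
```

In PyTorch, the analogous operation is `Tensor.scatter_reduce_` with `reduce="amax"`, which applies the same per-index maximum in parallel on the device.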
batch_weight_max(ws)
Batch version of weight_max.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ws | Tensor | Batch of token weights, shape (batch_size × | required |
Returns:
Type | Description |
---|---|
ndarray | Maximum weights for each node in the trie, shape (batch_size × num_nodes). |