sequence
SMC
This class implements sequential Monte Carlo (SMC) inference for controlled text generation. The generation process works as follows:
- **Token Sampling**: At each step, the `unit_sampler` is used to extend each particle (candidate sequence) by sampling a new token. This grows all sequences by one token at a time. The sampler also outputs an importance weight with each extension to correct for the myopic nature of token-by-token sampling.
- **Critic Evaluation**: If a `critic` is provided, it scores the updated sequences (via its `score` method), reweighting the particles based on how well they satisfy the constraints encoded by the critic.
- **Resampling**: When the effective sample size (ESS) falls below the threshold, particles are resampled according to their weights. This helps focus computation on more promising sequences.
- **Termination**: The process continues until either:
    - All sequences reach an end-of-sequence (EOS) token
    - The maximum token length is reached
If a critic is provided, the resulting sequences are properly weighted with respect to the product of the unit sampler's target potential and the critic potential (`unit_sampler.target * critic`). If a critic is not provided, the resulting sequences are weighted with respect to the unit sampler's target potential.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `unit_sampler` | `TokenSampler` | The sampler that generates tokens. | *required* |
| `critic` | `Potential` | A potential function that guides the generation process by scoring candidate sequences. Must have the same token type as the `unit_sampler`. | `None` |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `unit_sampler` is not a `TokenSampler`, if `critic` is not a `Potential`, or if the token types of `unit_sampler` and `critic` don't match. |
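A minimal usage sketch is shown below. The `unit_sampler` and `critic` arguments are placeholders for objects constructed elsewhere (a `TokenSampler` and a `Potential`, respectively), and the import path is an assumption based on the source file location.

```python
import asyncio

# Assumed import path: SMC is defined in genlm/control/sampler/sequence.py and
# may also be re-exported from the package root.
from genlm.control.sampler.sequence import SMC


async def generate(unit_sampler, critic=None):
    # Wrap a TokenSampler (and, optionally, a Potential critic) in an SMC sampler.
    sampler = SMC(unit_sampler, critic=critic)

    # Run SMC inference with 10 particles, resampling whenever the effective
    # sample size drops below 0.5 * 10, for at most 50 tokens per sequence.
    return await sampler(n_particles=10, ess_threshold=0.5, max_tokens=50)


# sequences = asyncio.run(generate(my_unit_sampler, my_critic))
```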
Source code in genlm/control/sampler/sequence.py
`__call__(n_particles, ess_threshold, max_tokens, verbosity=0, json_path=None, **kwargs)` *async*
Generate sequences using sequential Monte Carlo inference.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_particles` | `int` | Number of particles (candidate sequences) to maintain during generation. Higher values provide better exploration but require more computation. | *required* |
| `ess_threshold` | `float` | Effective sample size threshold for resampling, expressed as a fraction of the number of particles. When the ESS falls below this value, particles are resampled according to their weights. Should be between 0 and 1. Higher values lead to more frequent resampling. Note that when `ess_threshold` = 0, the critic is only applied at the end of generation (if it is provided). | *required* |
| `max_tokens` | `int` | Maximum number of tokens to generate per sequence. Generation may terminate earlier if all sequences reach an EOS token. | *required* |
| `verbosity` | `int` | Verbosity level for the SMC algorithm. 0 is silent, 1 prints the particles at each step. Default is 0. | `0` |
| `json_path` | `str` | JSON file path for saving a record of the inference run. | `None` |
| `**kwargs` | `dict` | Additional keyword arguments to pass to the SMC algorithm. | `{}` |
Returns:

| Type | Description |
|---|---|
| `Sequences` | A container holding the generated sequences, their importance weights, and other metadata from the generation process. |
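As a sketch of a typical call (parameter values and the JSON path are illustrative; `sampler` is assumed to be an `SMC` instance constructed as above):

```python
async def run(sampler):
    # `sampler` is assumed to be an SMC instance (see the class docs above).
    sequences = await sampler(
        n_particles=20,               # particles to maintain
        ess_threshold=0.5,            # resample when ESS < 0.5 * n_particles
        max_tokens=100,               # cap on tokens per sequence
        verbosity=1,                  # print the particles at each step
        json_path="smc_record.json",  # save a record of the run (illustrative path)
    )

    # The result is a `Sequences` container (documented below).
    print(sequences.ess)
    for context, log_w in zip(sequences.contexts, sequences.log_weights):
        print(log_w, context)

    return sequences
```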
Source code in genlm/control/sampler/sequence.py
`cleanup()` *async*
Clean up resources used by the inference engine.
This method should be called when the InferenceEngine is no longer needed.
Example
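A minimal sketch, assuming `cleanup()` is called on the `SMC` instance once generation has finished (`SMC` imported as in the earlier sketch):

```python
async def generate_with_cleanup(unit_sampler, critic=None):
    sampler = SMC(unit_sampler, critic=critic)
    try:
        return await sampler(n_particles=10, ess_threshold=0.5, max_tokens=50)
    finally:
        # Release any resources held by the engine once generation is done.
        await sampler.cleanup()
```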
Source code in genlm/control/sampler/sequence.py
`Sequences` *dataclass*
Container for sequence samples with their weights and probabilities.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `contexts` | `list` | List of token sequences generated by the sampler. | *required* |
| `log_weights` | `list` | Log importance weights for each sequence. | *required* |
Attributes:

| Name | Type | Description |
|---|---|---|
| `size` | `int` | Number of sequences in the container. |
| `logp` | `float` | Sum of log probabilities across all sequences. |
| `log_total` | `float` | Log of the sum of importance weights. |
| `log_ml` | `float` | Log marginal likelihood estimate. |
| `log_normalized_weights` | `list` | Log weights normalized to sum to 1. |
| `log_ess` | `float` | Log of the effective sample size. |
| `ess` | `float` | Effective sample size of the particle population. |
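To clarify what these attributes measure, the sketch below computes the conventional SMC summary quantities from a vector of log importance weights. It illustrates the standard definitions only and is not the library's implementation.

```python
import numpy as np
from scipy.special import logsumexp

# Illustrative log importance weights for four sequences.
log_weights = np.array([-1.2, -0.7, -3.5, -0.9])

size = len(log_weights)                   # number of sequences
log_total = logsumexp(log_weights)        # log of the sum of importance weights
log_ml = log_total - np.log(size)         # log marginal likelihood estimate
log_normalized = log_weights - log_total  # log weights normalized to sum to 1
log_ess = -logsumexp(2 * log_normalized)  # log effective sample size
ess = np.exp(log_ess)                     # between 1 and `size`
```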
Source code in genlm/control/sampler/sequence.py
`posterior` *cached property*
Compute the estimated posterior distribution over sequences.
The probability of a sequence corresponds to its normalized weight. The probabilities of duplicate sequences are summed.
Returns:

| Type | Description |
|---|---|
| `chart` | A normalized chart mapping sequences to their posterior probabilities, sorted in descending order by probability. |
`decoded_posterior` *cached property*
Compute posterior distribution over completed UTF-8 decodable sequences.
Filters for sequences that:

- End with an `EndOfSequence` token
- Can be decoded as UTF-8 strings
The probability of each sequence corresponds to its normalized weight among completed and decodable sequences. Probabilities of duplicate sequences (after decoding) are summed. To obtain the posterior distribution over all byte sequences, use `self.posterior`.
Returns:

| Type | Description |
|---|---|
| `chart` | A normalized chart mapping decoded string sequences to their posterior probabilities, sorted in descending order by probability. Only includes sequences that meet both filtering criteria. |
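A sketch of inspecting the two posterior views, assuming the returned charts expose a dict-like interface (`.items()`):

```python
# `sequences` is assumed to be a Sequences instance returned by SMC.__call__.

# Posterior over all sampled (byte) sequences.
for seq, prob in sequences.posterior.items():
    print(prob, seq)

# Posterior restricted to completed, UTF-8 decodable sequences, keyed by the
# decoded strings.
for text, prob in sequences.decoded_posterior.items():
    print(prob, text)
```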
`normalized_weights` *property*

Return the exponential of the normalized log weights.