lm_state
genlm.bytes.byte_lm.lm_state
StatefulTokenizedLM
A stateful tokenized language model that maintains a token context and computes next-token log probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AsyncLM | The underlying language model | required |
context | list | List of token IDs representing the current context | required |
n_calls | int | Number of times the model has been called | 0 |
max_context_length | int | Maximum length of context to maintain | None |
Source code in genlm/bytes/byte_lm/lm_state.py
initial(model, initial_context=None, max_context_length=None)
classmethod
Creates an initial state for the language model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AsyncLM | The language model to use | required |
initial_context | list | Initial context of token IDs. Defaults to [tokenizer.bos_token_id] | None |
max_context_length | int | Maximum context length to maintain | None |
Returns:
Type | Description |
---|---|
StatefulTokenizedLM | A new instance with initial state |
Source code in genlm/bytes/byte_lm/lm_state.py
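The defaulting behaviour of `initial` can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: `ToyTokenizer`, `ToyModel`, and `ToyState` are hypothetical names, and the sketch assumes the BOS token ID is reachable as `model.tokenizer.bos_token_id`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


class ToyTokenizer:
    # Hypothetical tokenizer exposing only a BOS token ID.
    bos_token_id = 1


class ToyModel:
    # Hypothetical stand-in for an AsyncLM; only the tokenizer is used here.
    tokenizer = ToyTokenizer()


@dataclass(frozen=True)
class ToyState:
    model: ToyModel
    context: Tuple[int, ...]
    max_context_length: Optional[int] = None

    @classmethod
    def initial(cls, model, initial_context=None, max_context_length=None):
        # Assumption: with no explicit context, start from the BOS token alone.
        if initial_context is None:
            initial_context = [model.tokenizer.bos_token_id]
        return cls(model, tuple(initial_context), max_context_length)


default_state = ToyState.initial(ToyModel())
custom_state = ToyState.initial(ToyModel(), initial_context=[1, 42, 7])
```

Passing `initial_context` explicitly is useful when generation should start from an already tokenized prompt rather than from the BOS token.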
__lshift__(token)
Adds a new token to the context and returns a new state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token | int | Token ID to add to context | required |
Returns:
Type | Description |
---|---|
StatefulTokenizedLM | New state with updated context |
Source code in genlm/bytes/byte_lm/lm_state.py
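Because `<<` returns a new state, an earlier state can be reused to explore several continuations. The sketch below illustrates that pattern with a hypothetical `ToyState` class; it does not model `max_context_length`, whose truncation behaviour is not described here.

```python
class ToyState:
    """Hypothetical stand-in that mimics the `state << token` pattern."""

    def __init__(self, context):
        self.context = list(context)

    def __lshift__(self, token):
        # Build a fresh state; the receiver's context is left untouched.
        return ToyState(self.context + [token])


s0 = ToyState([1])
s1 = s0 << 42
s2 = s1 << 7
branch = s1 << 99  # a second continuation from the same parent state
```

Here `s2` and `branch` share the prefix held by `s1`, while `s0` still holds only its original context.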
logp_next()
async
Computes log probabilities for the next token given the current context.
Returns:
Type | Description |
---|---|
Tensor | Log probabilities for next tokens |
Source code in genlm/bytes/byte_lm/lm_state.py
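Since `logp_next` is a coroutine, it has to be awaited. The following sketch shows the calling pattern with a hypothetical toy model that returns a uniform distribution as a plain list instead of a tensor. The method name `next_token_logprobs` on the toy model and the point at which `n_calls` is incremented are assumptions made for illustration.

```python
import asyncio
import math


class ToyModel:
    # Hypothetical model with a four-token vocabulary.
    vocab_size = 4

    async def next_token_logprobs(self, context):
        # Uniform log probabilities, independent of the context.
        return [math.log(1.0 / self.vocab_size)] * self.vocab_size


class ToyState:
    def __init__(self, model, context, n_calls=0):
        self.model = model
        self.context = list(context)
        self.n_calls = n_calls

    async def logp_next(self):
        # Assumption: each query to the model bumps the call counter.
        self.n_calls += 1
        return await self.model.next_token_logprobs(self.context)


async def main():
    state = ToyState(ToyModel(), [1])
    logps = await state.logp_next()
    return state, logps


state, logps = asyncio.run(main())
total = sum(math.exp(lp) for lp in logps)  # should be close to 1.0
```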
StatefulByteLM
Bases: ABC
Abstract base class for byte-level language models with state.
Source code in genlm/bytes/byte_lm/lm_state.py
__lshift__(b)
abstractmethod
async
Adds a byte to the state and returns a new state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
b | int | Byte to add | required |
Returns:
Type | Description |
---|---|
StatefulByteLM | New state with added byte |
prune()
Prunes the current state if needed.
Override in subclasses.
Returns:
Type | Description |
---|---|
StatefulByteLM | Pruned state |
logp_next()
abstractmethod
async
Computes the log probability distribution for the next byte.
Returns:
Type | Description |
---|---|
LazyByteProbs | Log probabilities for next possible bytes |
prefill(bs)
async
Prefills the model state with a sequence of bytes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bs | list[int] | Sequence of bytes to add to state | required |
Returns:
Type | Description |
---|---|
StatefulByteLM | New state with all bytes added |
Source code in genlm/bytes/byte_lm/lm_state.py
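Prefilling can be read as folding `<<` over the byte sequence. Because `__lshift__` is declared `async` on this class, the expression `state << b` yields a coroutine that must be awaited. The sketch below uses a hypothetical `ToyByteState`; calling `prune()` before each step is an assumption, not something the description above confirms.

```python
import asyncio


class ToyByteState:
    """Hypothetical byte-level state that simply records the bytes it has seen."""

    def __init__(self, seen=b""):
        self.seen = seen

    async def __lshift__(self, b):
        # Async operator: `state << b` returns a coroutine to await.
        return ToyByteState(self.seen + bytes([b]))

    def prune(self):
        # Default behaviour in this sketch: nothing to prune.
        return self

    async def prefill(self, bs):
        state = self
        for b in bs:
            # Assumption: prune before extending the state by one byte.
            state = await (state.prune() << b)
        return state


final = asyncio.run(ToyByteState().prefill(list(b"hi")))
```

Note the parentheses in `await (state.prune() << b)`: without them, `await` binds to `state.prune()` first, which is not awaitable.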
greedy(context, steps)
async
Performs greedy decoding for a given number of steps.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context | bytes | Initial byte context | required |
steps | int | Number of generation steps | required |
Returns:
Type | Description |
---|---|
bytes | Generated byte sequence |
Source code in genlm/bytes/byte_lm/lm_state.py
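Greedy decoding repeatedly picks the most likely next byte and feeds it back into the state. The following is a minimal sketch of that loop under stated assumptions: `CountingState` is a hypothetical toy whose favored next byte is always the previous byte plus one, the distribution is a plain dict rather than a `LazyByteProbs`, and the returned value includes the initial context (the description above does not say whether it does).

```python
import asyncio
import math


class CountingState:
    """Hypothetical toy state: the likeliest next byte is (last byte + 1) % 256."""

    def __init__(self, last=0):
        self.last = last

    async def __lshift__(self, b):
        return CountingState(b)

    async def logp_next(self):
        favored = (self.last + 1) % 256
        rest = math.log(0.1 / 255)  # remaining mass spread over the other bytes
        return {b: (math.log(0.9) if b == favored else rest) for b in range(256)}


async def greedy(state, context, steps):
    out = list(context)
    for b in context:  # consume the initial context first
        state = await (state << b)
    for _ in range(steps):
        logps = await state.logp_next()
        b = max(logps, key=logps.get)  # argmax over next-byte log probabilities
        out.append(b)
        state = await (state << b)
    return bytes(out)


result = asyncio.run(greedy(CountingState(), b"a", 3))  # b"abcd"
```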
sample(context, steps, draw=sample_dict)
async
Samples from the model for a given number of steps.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context | bytes | Initial byte context | required |
steps | int | Number of generation steps | required |
draw | | Sampling function to use (default: sample_dict) | sample_dict |
Returns:
Type | Description |
---|---|
bytes | Generated byte sequence |
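
The `draw` parameter makes the per-step sampling rule pluggable. The sketch below illustrates the idea with hypothetical names (`draw_from_dict`, `TwoByteState`); it assumes `draw` receives a dict mapping each byte to its probability and returns one byte, which may differ from the contract of the real `sample_dict`.

```python
import asyncio
import math
import random


def draw_from_dict(probs):
    # Hypothetical draw function: pick a key with probability given by its value.
    keys = list(probs)
    return random.choices(keys, weights=[probs[k] for k in keys], k=1)[0]


class TwoByteState:
    """Hypothetical toy state that always predicts 'x' or 'y' with equal mass."""

    async def __lshift__(self, b):
        return self

    async def logp_next(self):
        return {ord("x"): math.log(0.5), ord("y"): math.log(0.5)}


async def sample(state, context, steps, draw=draw_from_dict):
    out = list(context)
    for b in context:
        state = await (state << b)
    for _ in range(steps):
        logps = await state.logp_next()
        probs = {b: math.exp(lp) for b, lp in logps.items()}
        b = draw(probs)  # the sampling rule is supplied by the caller
        out.append(b)
        state = await (state << b)
    return bytes(out)


random.seed(0)
sampled = asyncio.run(sample(TwoByteState(), b">", 5))
# Swapping in a deterministic draw function forces a fixed continuation.
forced = asyncio.run(sample(TwoByteState(), b">", 5, draw=lambda p: ord("x")))
```

Supplying a custom `draw` in this way is one route to variants such as temperature scaling or top-k filtering without changing the decoding loop itself.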