Basic usage
This example shows how to use the `genlm-bytes` library for byte-level language modeling.
from genlm.bytes import ByteBeamState, BeamParams
from genlm.backend import load_model_by_name
First, load a token-level language model from a Hugging Face model name. Depending on whether CUDA is available, the model will be loaded using either a Hugging Face (CPU) or vLLM (GPU) backend.
llm = load_model_by_name("gpt2-medium")
Initialize a beam state with a maximum beam width of 5.
beam = await ByteBeamState.initial(llm, BeamParams(K=5))
Populate the beam state with the context. The return value is a new beam state.
beam = await beam.prefill(b"An apple a day keeps the ")
beam
Z: -19.598907929485275
Candidates:
(1.0000) -19.60: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣
(0.0000) -31.03: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣
(0.0000) -36.22: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣
(0.0000) -36.49: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣
(0.0000) -40.52: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t|he|␣
- Each candidate in the beam corresponds to a sequence of tokens (in purple) and a partial token (in green).
- Each candidate has an associated log weight (the negative numbers in grey), which is the log probability of the sequence of tokens and the partial token.
- The `Z` value corresponds to our estimate of the log partition function, which is the estimate of the prefix probability of the context under the language model.
- Each candidate also has an associated probability (shown on the left in green), which is its weight normalized by the partition function.
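As a sanity check, the normalization described above can be reproduced in a few lines of plain Python. The log weights below are copied from the printout; this is only an illustration of the arithmetic, not the library's internals.

```python
import math

# Log weights of the five candidates from the beam printout above.
log_weights = [-19.60, -31.03, -36.22, -36.49, -40.52]

# Z is the log partition function: logsumexp of the candidate weights.
Z = math.log(sum(math.exp(w) for w in log_weights))

# Each candidate's probability is its weight normalized by Z.
probs = [math.exp(w - Z) for w in log_weights]
```

Since the first candidate dominates, `Z` is very close to its log weight, and its normalized probability is essentially 1, matching the `(1.0000)` shown in the printout.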
We can use the `logp_next` method to get the (log) probability distribution over the next byte.
# Get the log probability distribution over the next byte.
logp_next = await beam.logp_next()
logp_next.pretty().top(5) # Show the top 5 most probable next bytes
| key | value |
| --- | --- |
| b'd' | -0.5768002911707057 |
| b'b' | -2.8733914084455527 |
| b's' | -2.981722712805219 |
| b'w' | -3.375940367664043 |
| b'm' | -3.5282914648667756 |
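These values are log probabilities, so exponentiating recovers ordinary probabilities; the numbers below are copied (rounded) from the table, purely for illustration.

```python
import math

# Top next-byte log probabilities, rounded from the table above.
logp = {b"d": -0.5768, b"b": -2.8734, b"s": -2.9817}

# Exponentiate to recover ordinary probabilities; 'd' gets roughly 56%.
p_d = math.exp(logp[b"d"])
```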
To advance the beam by the next byte, we first prune it to keep only the top 5 candidates, and then use the `<<` operator to feed in the next byte.
new_beam = await (beam.prune() << 100) # 100 is the byte value of 'd'
new_beam
Z: -20.17567801749765
Candidates:
(1.0000) -20.18: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -31.93: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣d
(0.0000) -38.71: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -39.28: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣d
(0.0000) -40.25: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣|d
(0.0000) -43.16: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t|he|␣d
(0.0000) -51.34: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣|d
(0.0000) -54.79: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣|d
(0.0000) -56.64: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣|d
(0.0000) -58.94: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t|he|␣|d
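The `<<` operator consumes integer byte values, and iterating over a Python `bytes` object yields exactly those integers, which is why `100` above corresponds to `'d'`:

```python
# Iterating over a bytes object yields integer byte values,
# e.g. ord('d') == 100.
byte_values = list(b"doctor")
# Each value could be fed to the beam in turn, e.g.
# beam = await (beam.prune() << b) for each b in byte_values.
```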
Since extending the beam by one byte can grow the number of candidates, we can again prune it to keep only the top 5 candidates:
pruned_beam = new_beam.prune()
pruned_beam
Z: -20.175678017602173
Candidates:
(1.0000) -20.18: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -31.93: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣d
(0.0000) -38.71: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -39.28: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣d
(0.0000) -40.25: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣|d
We can further speed up the algorithm with a more aggressive pruning strategy. In particular, `BeamParams` has a `prune_threshold` parameter, which controls the minimum probability a candidate must have to be kept in the beam. Higher values lead to more aggressive pruning, which significantly reduces the number of language model calls we need to make.
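The effect of a probability threshold can be sketched in plain Python. This mirrors the behavior described above (drop candidates whose normalized probability falls below the threshold), not the library's actual implementation:

```python
import math

def prune_by_threshold(log_weights, threshold):
    """Keep candidates whose normalized probability meets the threshold.

    `log_weights` maps candidate ids to log weights; illustrative only.
    """
    # Normalize via the log partition function Z = logsumexp(weights).
    Z = math.log(sum(math.exp(w) for w in log_weights.values()))
    return {c: w for c, w in log_weights.items() if math.exp(w - Z) >= threshold}

# With weights like those printed above, only the dominant candidate survives
# a 0.05 threshold.
kept = prune_by_threshold({"a": -19.60, "b": -31.03, "c": -36.22}, 0.05)
```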
beam = await ByteBeamState.initial(llm, BeamParams(K=5, prune_threshold=0.05))
beam = await beam.prefill(b"An apple a day keeps the ")
beam
Z: -19.598918914794922
Candidates:
(1.0000) -19.60: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣
logp_next = await beam.logp_next()
logp_next.pretty().top(5)
| key | value |
| --- | --- |
| b'd' | -0.5766762743944795 |
| b'b' | -2.8732729803080233 |
| b's' | -2.9816068063730867 |
| b'w' | -3.3758250127787264 |
| b'm' | -3.528177345847574 |
Putting it all together, we can generate a sequence of bytes by repeatedly selecting a next byte from the log probability distribution and advancing the beam by that byte.
One selection strategy is to always select the byte with the highest log probability, which is what `greedy` does:
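The selection rule itself is just an argmax over the next-byte distribution. A pure-Python sketch, using illustrative log-probability values (the library's `greedy` also advances the beam after each selection):

```python
# Hypothetical next-byte log probabilities.
logp = {b"d": -0.58, b"b": -2.87, b"s": -2.98}

# Greedy selection: pick the byte with the highest log probability.
best = max(logp, key=logp.get)
```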
beam = await ByteBeamState.initial(
llm, BeamParams(K=5, prune_threshold=0.05, verbose=True)
)
sampled = await beam.greedy(b"An apple a day keeps the ", steps=12)
sampled
Z: -2.174436330795288 Candidates: (1.0000) -2.17: <|endoftext|>|A
Z: -4.501037198697343 Candidates: (0.9977) -4.50: <|endoftext|>|An (0.0023) -10.56: <|endoftext|>|A|n
Z: -5.643285751342773 Candidates: (1.0000) -5.64: <|endoftext|>|An|␣
Z: -7.201362133026123 Candidates: (1.0000) -7.20: <|endoftext|>|An|␣a
Z: -10.39808464050293 Candidates: (1.0000) -10.40: <|endoftext|>|An|␣ap
Z: -10.627063751220703 Candidates: (1.0000) -10.63: <|endoftext|>|An|␣app
Z: -12.216539396150903 Candidates: (0.9934) -12.22: <|endoftext|>|An|␣appl (0.0066) -17.23: <|endoftext|>|An|␣app|l
Z: -13.993473052978516 Candidates: (1.0000) -13.99: <|endoftext|>|An|␣apple
Z: -14.135584831237793 Candidates: (1.0000) -14.14: <|endoftext|>|An|␣apple|␣
Z: -16.980817794799805 Candidates: (1.0000) -16.98: <|endoftext|>|An|␣apple|␣a
Z: -17.91916847229004 Candidates: (1.0000) -17.92: <|endoftext|>|An|␣apple|␣a|␣
Z: -18.107898712158203 Candidates: (1.0000) -18.11: <|endoftext|>|An|␣apple|␣a|␣d
Z: -18.110923767089844 Candidates: (1.0000) -18.11: <|endoftext|>|An|␣apple|␣a|␣da
Z: -18.111148834228516 Candidates: (1.0000) -18.11: <|endoftext|>|An|␣apple|␣a|␣day
Z: -18.135374069213867 Candidates: (1.0000) -18.14: <|endoftext|>|An|␣apple|␣a|␣day|␣
Z: -18.454233169555664 Candidates: (1.0000) -18.45: <|endoftext|>|An|␣apple|␣a|␣day|␣k
Z: -18.4615478515625 Candidates: (1.0000) -18.46: <|endoftext|>|An|␣apple|␣a|␣day|␣ke
Z: -18.469377517700195 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣kee
Z: -18.469377517700195 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣keep
Z: -18.469871520996094 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps
Z: -18.472919464111328 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣
Z: -19.51340675354004 Candidates: (1.0000) -19.51: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t
Z: -19.53944206237793 Candidates: (1.0000) -19.54: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th
Z: -19.567386627197266 Candidates: (1.0000) -19.57: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the
Z: -19.598918914794922 Candidates: (1.0000) -19.60: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣
Z: -20.17568588256836 Candidates: (1.0000) -20.18: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣d
Z: -20.21005630493164 Candidates: (1.0000) -20.21: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣do
Z: -20.214828491210938 Candidates: (1.0000) -20.21: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doc
Z: -20.21491241455078 Candidates: (1.0000) -20.21: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doct
Z: -20.214920043945312 Candidates: (1.0000) -20.21: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣docto
Z: -20.214920043945312 Candidates: (1.0000) -20.21: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor
Z: -20.248199462890625 Candidates: (1.0000) -20.25: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor|␣
Z: -20.292394638061523 Candidates: (1.0000) -20.29: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor|␣a
Z: -20.30491065979004 Candidates: (1.0000) -20.30: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor|␣aw
Z: -20.304912567138672 Candidates: (1.0000) -20.30: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor|␣awa
Z: -20.30513572692871 Candidates: (1.0000) -20.31: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor|␣away
Z: -21.068042755126953 Candidates: (1.0000) -21.07: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣doctor|␣away|.
b'An apple a day keeps the doctor away.'
We can also sample from the log probability distribution over the next byte:
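The sampling rule can be sketched in plain Python: exponentiate the log probabilities and draw a byte in proportion to them (illustrative values; the library's `sample` also advances the beam after each draw):

```python
import math
import random

# Hypothetical next-byte log probabilities.
logp = {b"d": -0.58, b"b": -2.87, b"s": -2.98}

# Draw one byte with probability proportional to exp(log probability).
weights = [math.exp(v) for v in logp.values()]
choice = random.choices(list(logp), weights=weights, k=1)[0]
```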
beam = await ByteBeamState.initial(
llm, BeamParams(K=5, prune_threshold=0.05, verbose=True)
)
sampled = await beam.sample(b"An apple a day keeps the ", steps=12)
sampled
Z: -2.174436330795288 Candidates: (1.0000) -2.17: <|endoftext|>|A
Z: -4.501037198697343 Candidates: (0.9977) -4.50: <|endoftext|>|An (0.0023) -10.56: <|endoftext|>|A|n
Z: -5.643285751342773 Candidates: (1.0000) -5.64: <|endoftext|>|An|␣
Z: -7.201362133026123 Candidates: (1.0000) -7.20: <|endoftext|>|An|␣a
Z: -10.39808464050293 Candidates: (1.0000) -10.40: <|endoftext|>|An|␣ap
Z: -10.627063751220703 Candidates: (1.0000) -10.63: <|endoftext|>|An|␣app
Z: -12.216539396150903 Candidates: (0.9934) -12.22: <|endoftext|>|An|␣appl (0.0066) -17.23: <|endoftext|>|An|␣app|l
Z: -13.993473052978516 Candidates: (1.0000) -13.99: <|endoftext|>|An|␣apple
Z: -14.135584831237793 Candidates: (1.0000) -14.14: <|endoftext|>|An|␣apple|␣
Z: -16.980817794799805 Candidates: (1.0000) -16.98: <|endoftext|>|An|␣apple|␣a
Z: -17.91916847229004 Candidates: (1.0000) -17.92: <|endoftext|>|An|␣apple|␣a|␣
Z: -18.107898712158203 Candidates: (1.0000) -18.11: <|endoftext|>|An|␣apple|␣a|␣d
Z: -18.110923767089844 Candidates: (1.0000) -18.11: <|endoftext|>|An|␣apple|␣a|␣da
Z: -18.111148834228516 Candidates: (1.0000) -18.11: <|endoftext|>|An|␣apple|␣a|␣day
Z: -18.135374069213867 Candidates: (1.0000) -18.14: <|endoftext|>|An|␣apple|␣a|␣day|␣
Z: -18.454233169555664 Candidates: (1.0000) -18.45: <|endoftext|>|An|␣apple|␣a|␣day|␣k
Z: -18.4615478515625 Candidates: (1.0000) -18.46: <|endoftext|>|An|␣apple|␣a|␣day|␣ke
Z: -18.469377517700195 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣kee
Z: -18.469377517700195 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣keep
Z: -18.469871520996094 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps
Z: -18.472919464111328 Candidates: (1.0000) -18.47: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣
Z: -19.51340675354004 Candidates: (1.0000) -19.51: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t
Z: -19.53944206237793 Candidates: (1.0000) -19.54: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th
Z: -19.567386627197266 Candidates: (1.0000) -19.57: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the
Z: -19.598918914794922 Candidates: (1.0000) -19.60: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣
Z: -24.765562057495117 Candidates: (1.0000) -24.77: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣D
Z: -27.6703574912875 Candidates: (0.9409) -27.73: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Du (0.0591) -30.50: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣D|u
Z: -28.567639703725348 Candidates: (0.9906) -28.58: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dut (0.0094) -33.24: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣D|ut
Z: -28.577476501464844 Candidates: (1.0000) -28.58: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutc
Z: -28.577476501464844 Candidates: (1.0000) -28.58: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch
Z: -28.774438858032227 Candidates: (1.0000) -28.77: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣
Z: -31.404659271240234 Candidates: (1.0000) -31.40: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣a
Z: -32.12074279785156 Candidates: (1.0000) -32.12: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣aw
Z: -32.12383270263672 Candidates: (1.0000) -32.12: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣awa
Z: -32.18830490112305 Candidates: (1.0000) -32.19: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣away
Z: -33.56272506713867 Candidates: (1.0000) -33.56: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣away|.
Z: -34.75253677368164 Candidates: (1.0000) -34.75: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣Dutch|␣away|.|\n
b'An apple a day keeps the Dutch away.\n'