Skip to content

util

LazyWeights

A class to represent weights in a lazy manner, allowing for efficient operations on potentially large weight arrays without immediate materialization.

Attributes:

Name Type Description
weights ndarray

The weights associated with the tokens.

encode dict

A mapping from tokens to their corresponding indices in the weights array.

decode list

A list of tokens corresponding to the weights.

is_log bool

A flag indicating whether the weights are in log space.

Source code in genlm/control/util.py
class LazyWeights:
    """
    A class to represent weights in a lazy manner, allowing for efficient operations
    on potentially large weight arrays without immediate materialization.

    Attributes:
        weights (np.ndarray): The weights associated with the tokens.
        encode (dict): A mapping from tokens to their corresponding indices in the weights array.
        decode (list): A list of tokens corresponding to the weights.
        is_log (bool): A flag indicating whether the weights are in log space.
    """

    def __init__(self, weights, encode, decode, log=True):
        """
        Initialize the LazyWeights instance.

        Args:
            weights (np.ndarray): The weights associated with the tokens.
            encode (dict): A mapping from tokens to their corresponding indices in the weights array.
            decode (list): A list of tokens corresponding to the weights.
            log (bool, optional): Indicates if the weights are in log space. Defaults to True.

        Raises:
            AssertionError: If the lengths of weights and decode or encode do not match.
        """
        assert len(weights) == len(decode)
        assert len(encode) == len(decode)

        self.weights = weights
        self.encode = encode
        self.decode = decode
        self.is_log = log

    def __getitem__(self, token):
        """
        Retrieve the weight for a given token.

        Args:
            token (Any): The token for which to retrieve the weight.

        Returns:
            (float): The weight of the token. Tokens missing from `encode` yield the
                additive identity of the current space: -inf in log space, 0 otherwise.
        """
        if token not in self.encode:
            return float("-inf") if self.is_log else 0
        return self.weights[self.encode[token]]

    def __len__(self):
        """Return the number of entries in the weights array."""
        return len(self.weights)

    def __array__(self, dtype=None, copy=None):
        """
        Refuse conversion to a numpy array.

        The optional `dtype` and `copy` arguments are accepted (and ignored) so that
        callers passing them get the explanatory NotImplementedError below rather
        than a TypeError about unexpected arguments.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "LazyWeights cannot be converted to a numpy array. "
            "If you want to combine multiple LazyWeights, use their weights attribute directly."
        )

    def keys(self):
        """Return the list of tokens (keys) in the vocabulary."""
        return self.decode

    def values(self):
        """Return the weights associated with the tokens."""
        return self.weights

    def items(self):
        """Return a zip of tokens and weights."""
        return zip(self.keys(), self.values())

    def normalize(self):
        """
        Normalize the weights.

        Normalization is performed using log-space arithmetic when weights are logarithmic,
        or standard arithmetic otherwise.

        Returns:
            (LazyWeights): A new LazyWeights instance with normalized weights.
        """
        if self.is_log:
            return self.spawn(self.weights - logsumexp(self.weights))
        else:
            return self.spawn(self.weights / np.sum(self.weights))

    def exp(self):
        """
        Exponentiate the weights. This operation can only be performed when weights are in log space.

        Returns:
            (LazyWeights): A new LazyWeights instance with exponentiated weights.

        Raises:
            AssertionError: If the weights are not in log space.
        """
        assert self.is_log, "Weights must be in log space to exponentiate"
        return self.spawn(np.exp(self.weights), log=False)

    def log(self):
        """
        Take the logarithm of the weights. This operation can only be performed when weights are in regular space.

        Returns:
            (LazyWeights): A new LazyWeights instance with logarithmic weights.

        Raises:
            AssertionError: If the weights are already in log space.
        """
        assert not self.is_log, "Weights must be in regular space to take the logarithm"
        return self.spawn(np.log(self.weights), log=True)

    def sum(self):
        """
        Sum the weights.

        Summation is performed using log-space arithmetic when weights are logarithmic,
        or standard arithmetic otherwise.

        Returns:
            (float): The sum of the weights, either in log space or regular space.
        """
        if self.is_log:
            return logsumexp(self.weights)
        else:
            return np.sum(self.weights)

    def spawn(self, new_weights, log=None):
        """
        Create a new LazyWeights instance over the same vocabulary with new weights.

        Args:
            new_weights (np.ndarray): The new weights for the LazyWeights instance.
            log (bool, optional): Indicates if the new weights are in log space.
                Defaults to None, meaning "same space as this instance".

        Returns:
            (LazyWeights): A new LazyWeights instance.
        """
        if log is None:
            log = self.is_log
        return LazyWeights(
            weights=new_weights, encode=self.encode, decode=self.decode, log=log
        )

    def _top_indices(self, top=None):
        """
        Return indices into `weights` sorted by ascending weight, optionally
        restricted to the `top` largest entries.

        Args:
            top (int, optional): Number of largest weights to keep. None keeps all.

        Returns:
            The selected indices in ascending weight order.
        """
        order = self.weights.argsort()
        if top is None:
            return order
        count = int(top)
        if count == 0:
            # `order[-0:]` is `order[0:]`, i.e. everything; asking for the top 0
            # entries must produce an empty selection instead.
            return order[:0]
        return order[-count:]

    def materialize(self, top=None):
        """
        Materialize the weights into a chart.

        Entries are inserted from highest to lowest weight. The chart is built from
        `Log` when the weights are in log space and from `Float` otherwise.

        Args:
            top (int, optional): The number of top weights to materialize. Defaults to None,
                which materializes every entry. `top=0` yields an empty chart.

        Returns:
            (Chart): A chart representation of the weights.
        """
        weights = self.weights
        semiring = Log if self.is_log else Float

        chart = semiring.chart()
        for i in reversed(self._top_indices(top)):
            chart[self.decode[i]] = weights[i]

        return chart

    def __repr__(self):
        """Return the repr of the fully materialized chart."""
        return repr(self.materialize())

    def assert_equal(self, other, **kwargs):
        """
        Assert that two LazyWeights instances are equal.

        This method asserts that the two LazyWeights instances have the same vocabulary
        (in identical order) and that their weights are numerically close.

        Args:
            other (LazyWeights): The other LazyWeights instance to compare.
            **kwargs (dict): Additional arguments for np.testing.assert_allclose (e.g., rtol, atol).

        Raises:
            AssertionError: If the vocabularies differ or the weights are not close.
        """
        assert self.decode == other.decode
        np.testing.assert_allclose(self.weights, other.weights, **kwargs)

    def assert_equal_unordered(self, other, **kwargs):
        """
        Assert that two LazyWeights instances are equal, ignoring vocabulary order.

        Args:
            other (LazyWeights): The other LazyWeights instance to compare.
            **kwargs (dict): Additional arguments for np.isclose (e.g., rtol, atol).

        Raises:
            AssertionError: If the token sets differ or any token's weights are not close.
        """
        assert set(self.decode) == set(other.decode), "keys do not match"

        for x in self.decode:
            have, want = self[x], other[x]
            assert np.isclose(have, want, **kwargs), f"{x}: {have} != {want}"

__init__(weights, encode, decode, log=True)

Initialize the LazyWeights instance.

Parameters:

Name Type Description Default
weights ndarray

The weights associated with the tokens.

required
encode dict

A mapping from tokens to their corresponding indices in the weights array.

required
decode list

A list of tokens corresponding to the weights.

required
log bool

Indicates if the weights are in log space. Defaults to True.

True

Raises:

Type Description
AssertionError

If the lengths of weights and decode or encode do not match.

Source code in genlm/control/util.py
def __init__(self, weights, encode, decode, log=True):
    """
    Set up a LazyWeights instance from a weight array and its vocabulary maps.

    Args:
        weights (np.ndarray): One weight per token.
        encode (dict): Token -> index into `weights`.
        decode (list): Index -> token; same length as `weights`.
        log (bool, optional): Whether `weights` are log-space values. Defaults to True.

    Raises:
        AssertionError: If `weights`, `encode` and `decode` do not all have the same length.
    """
    n_weights, n_decode = len(weights), len(decode)
    assert n_weights == n_decode
    assert len(encode) == n_decode

    self.weights, self.encode, self.decode = weights, encode, decode
    self.is_log = log

__getitem__(token)

Retrieve the weight for a given token.

Parameters:

Name Type Description Default
token Any

The token for which to retrieve the weight.

required

Returns:

Type Description
float

The weight of the token, or -inf/0 if the token is not found.

Source code in genlm/control/util.py
def __getitem__(self, token):
    """
    Look up the weight stored for `token`.

    Args:
        token (Any): The token to look up.

    Returns:
        (float): The token's weight. A token absent from `encode` yields
            -inf when the weights are in log space and 0 otherwise.
    """
    if token in self.encode:
        return self.weights[self.encode[token]]
    return float("-inf") if self.is_log else 0

keys()

Return the list of tokens (keys) in the vocabulary.

Source code in genlm/control/util.py
def keys(self):
    """Return the vocabulary tokens (the `decode` list itself, not a copy)."""
    vocabulary = self.decode
    return vocabulary

values()

Return the weights associated with the tokens.

Source code in genlm/control/util.py
def values(self):
    """Return the underlying weights (the `weights` object itself, not a copy)."""
    weight_array = self.weights
    return weight_array

items()

Return a zip of tokens and weights.

Source code in genlm/control/util.py
def items(self):
    """Return a zip pairing each token from `keys()` with its weight from `values()`."""
    tokens = self.keys()
    token_weights = self.values()
    return zip(tokens, token_weights)

normalize()

Normalize the weights.

Normalization is performed using log-space arithmetic when weights are logarithmic, or standard arithmetic otherwise.

Returns:

Type Description
LazyWeights

A new LazyWeights instance with normalized weights.

Source code in genlm/control/util.py
def normalize(self):
    """
    Return a normalized copy of these weights.

    Log-space weights have their logsumexp subtracted; regular-space weights
    are divided by their sum.

    Returns:
        (LazyWeights): A new instance, created through `spawn`, holding the normalized weights.
    """
    current = self.weights
    if not self.is_log:
        return self.spawn(current / np.sum(current))
    return self.spawn(current - logsumexp(current))

exp()

Exponentiate the weights. This operation can only be performed when weights are in log space.

Returns:

Type Description
LazyWeights

A new LazyWeights instance with exponentiated weights.

Raises:

Type Description
AssertionError

If the weights are not in log space.

Source code in genlm/control/util.py
def exp(self):
    """
    Convert log-space weights to regular space by exponentiating them.

    Returns:
        (LazyWeights): A new instance, created through `spawn`, flagged as not in log space.

    Raises:
        AssertionError: If the weights are not in log space.
    """
    assert self.is_log, "Weights must be in log space to exponentiate"
    exponentiated = np.exp(self.weights)
    return self.spawn(exponentiated, log=False)

log()

Take the logarithm of the weights. This operation can only be performed when weights are in regular space.

Returns:

Type Description
LazyWeights

A new LazyWeights instance with logarithmic weights.

Raises:

Type Description
AssertionError

If the weights are already in log space.

Source code in genlm/control/util.py
def log(self):
    """
    Convert regular-space weights to log space by taking their logarithm.

    Returns:
        (LazyWeights): A new instance, created through `spawn`, flagged as in log space.

    Raises:
        AssertionError: If the weights are already in log space.
    """
    assert not self.is_log, "Weights must be in regular space to take the logarithm"
    logged = np.log(self.weights)
    return self.spawn(logged, log=True)

sum()

Sum the weights.

Summation is performed using log-space arithmetic when weights are logarithmic, or standard arithmetic otherwise.

Returns:

Type Description
float

The sum of the weights, either in log space or regular space.

Source code in genlm/control/util.py
def sum(self):
    """
    Total the weights in whichever space they live in.

    Log-space weights are combined with logsumexp; regular-space weights
    are added with np.sum.

    Returns:
        (float): The total, expressed in the same space as the weights.
    """
    return logsumexp(self.weights) if self.is_log else np.sum(self.weights)

spawn(new_weights, log=None)

Create a new LazyWeights instance over the same vocabulary with new weights.

Parameters:

Name Type Description Default
new_weights ndarray

The new weights for the LazyWeights instance.

required
log bool

Indicates if the new weights are in log space. Defaults to None.

None

Returns:

Type Description
LazyWeights

A new LazyWeights instance.

Source code in genlm/control/util.py
def spawn(self, new_weights, log=None):
    """
    Build a LazyWeights over this instance's vocabulary with different weights.

    The `encode` and `decode` objects are passed through as-is, not copied.

    Args:
        new_weights (np.ndarray): Weights for the new instance.
        log (bool, optional): Whether `new_weights` are in log space. When None
            (the default), the new instance inherits this instance's `is_log`.

    Returns:
        (LazyWeights): The new instance.
    """
    in_log_space = self.is_log if log is None else log
    return LazyWeights(
        weights=new_weights,
        encode=self.encode,
        decode=self.decode,
        log=in_log_space,
    )

materialize(top=None)

Materialize the weights into a chart.

Parameters:

Name Type Description Default
top int

The number of top weights to materialize. Defaults to None.

None

Returns:

Type Description
Chart

A chart representation of the weights.

Source code in genlm/control/util.py
def materialize(self, top=None):
    """
    Materialize the weights into a chart.

    Entries are inserted from highest to lowest weight. The chart is built from
    `Log` when the weights are in log space and from `Float` otherwise.

    Args:
        top (int, optional): The number of top weights to materialize. Defaults to None,
            which materializes every entry. `top=0` yields an empty chart.

    Returns:
        (Chart): A chart representation of the weights.
    """
    weights = self.weights
    top_ws = weights.argsort()
    if top is not None:
        count = int(top)
        # `top_ws[-0:]` is `top_ws[0:]`, i.e. everything; asking for the top 0
        # entries must select nothing instead.
        top_ws = top_ws[-count:] if count != 0 else top_ws[:0]

    semiring = Log if self.is_log else Float

    chart = semiring.chart()
    for i in reversed(top_ws):
        chart[self.decode[i]] = weights[i]

    return chart

assert_equal(other, **kwargs)

Assert that two LazyWeights instances are equal.

This method asserts that the two LazyWeights instances have the same vocabulary (in identical order) and that their weights are numerically close.

Parameters:

Name Type Description Default
other LazyWeights

The other LazyWeights instance to compare.

required
**kwargs dict

Additional arguments for np.testing.assert_allclose (e.g., rtol, atol).

{}
Source code in genlm/control/util.py
def assert_equal(self, other, **kwargs):
    """
    Check that `other` has the same vocabulary, in the same order, and close weights.

    Args:
        other (LazyWeights): The instance to compare against.
        **kwargs (dict): Forwarded to np.testing.assert_allclose (e.g., rtol, atol).

    Raises:
        AssertionError: If the `decode` lists differ or the weights are not close.
    """
    same_vocabulary = self.decode == other.decode
    assert same_vocabulary
    np.testing.assert_allclose(self.weights, other.weights, **kwargs)

assert_equal_unordered(other, **kwargs)

Assert that two LazyWeights instances are equal, ignoring vocabulary order.

Parameters:

Name Type Description Default
other LazyWeights

The other LazyWeights instance to compare.

required
**kwargs dict

Additional arguments for np.isclose (e.g., rtol, atol).

{}
Source code in genlm/control/util.py
def assert_equal_unordered(self, other, **kwargs):
    """
    Check that `other` covers the same tokens with close weights, ignoring vocabulary order.

    Args:
        other (LazyWeights): The instance to compare against.
        **kwargs (dict): Forwarded to np.isclose (e.g., rtol, atol).

    Raises:
        AssertionError: If the token sets differ or any token's weights are not close.
    """
    assert set(self.decode) == set(other.decode), "keys do not match"

    for token in self.decode:
        have = self[token]
        want = other[token]
        assert np.isclose(have, want, **kwargs), f"{token}: {have} != {want}"

load_trie(V, backend=None, **kwargs)

Load a TokenCharacterTrie.

Parameters:

Name Type Description Default
V list

The vocabulary.

required
backend str

The backend to use for trie construction. Defaults to None.

None
**kwargs dict

Additional arguments for the trie construction.

{}

Returns:

Type Description
TokenCharacterTrie

A trie instance.

Source code in genlm/control/util.py
def load_trie(V, backend=None, **kwargs):
    """
    Build a token-character trie over a vocabulary.

    Args:
        V (list): The vocabulary.
        backend (str, optional): Which trie implementation to build. "parallel" selects
            ParallelTokenCharacterTrie; any other value selects TokenCharacterTrie.
            When None (the default), "parallel" is used if torch reports CUDA as
            available and "sequential" otherwise.
        **kwargs (dict): Forwarded to the trie constructor.

    Returns:
        (TokenCharacterTrie): A trie instance.
    """
    import torch

    chosen = backend
    if chosen is None:
        chosen = "parallel" if torch.cuda.is_available() else "sequential"

    # The trie classes are imported lazily so only the selected one is loaded.
    if chosen != "parallel":
        from genlm.backend.trie import TokenCharacterTrie

        return TokenCharacterTrie(V, **kwargs)

    from genlm.backend.trie import ParallelTokenCharacterTrie

    return ParallelTokenCharacterTrie(V, **kwargs)

load_async_trie(V, backend=None, **kwargs)

Load an AsyncTokenCharacterTrie. This is a TokenCharacterTrie that automatically batches weight_sum and weight_max requests.

Parameters:

Name Type Description Default
V list

The vocabulary.

required
backend str

The backend to use for trie construction. Defaults to None.

None
**kwargs dict

Additional arguments for the trie construction.

{}

Returns:

Type Description
AsyncTokenCharacterTrie

An async trie instance.

Source code in genlm/control/util.py
def load_async_trie(V, backend=None, **kwargs):
    """
    Build an AsyncTokenCharacterTrie: a TokenCharacterTrie that automatically
    batches weight_sum and weight_max requests.

    The underlying trie is created by `load_trie` and then wrapped.

    Args:
        V (list): The vocabulary.
        backend (str, optional): Backend passed through to `load_trie`. Defaults to None.
        **kwargs (dict): Additional arguments passed through to `load_trie`.

    Returns:
        (AsyncTokenCharacterTrie): An async trie instance.
    """
    from genlm.backend.trie import AsyncTokenCharacterTrie

    base_trie = load_trie(V, backend, **kwargs)
    return AsyncTokenCharacterTrie(base_trie)

fast_sample_logprobs(logprobs, size=1)

Sample indices from an array of log probabilities using the Gumbel-max trick.

Parameters:

Name Type Description Default
logprobs ndarray

Array of log probabilities

required
size int

Number of samples to draw

1

Returns:

Type Description
ndarray

Array of sampled indices

Note

This is much faster than np.random.choice for large arrays since it avoids normalizing probabilities and uses vectorized operations.

Source code in genlm/control/util.py
def fast_sample_logprobs(logprobs: np.ndarray, size: int = 1) -> np.ndarray:
    """Draw indices from an array of log probabilities with the Gumbel-max trick.

    A `(size, len(logprobs))` matrix of noise, `-log(-log(U))` with `U` taken from
    `np.random.random`, is added to `logprobs`; each row's argmax is one sample.

    Args:
        logprobs: Array of log probabilities
        size (int): Number of samples to draw

    Returns:
        (np.ndarray): Array of `size` sampled indices

    Note:
        This is much faster than np.random.choice for large arrays since it avoids
        normalizing probabilities and uses vectorized operations.
    """
    uniform = np.random.random((size, len(logprobs)))
    gumbel = -np.log(-np.log(uniform))
    perturbed = logprobs + gumbel
    return perturbed.argmax(axis=1)

fast_sample_lazyweights(lazyweights)

Sample a token from a LazyWeights instance using the Gumbel-max trick.

Parameters:

Name Type Description Default
lazyweights LazyWeights

A LazyWeights instance

required

Returns:

Type Description
Any

Sampled token

Source code in genlm/control/util.py
def fast_sample_lazyweights(lazyweights):
    """Draw one token from a log-space LazyWeights instance using the Gumbel-max trick.

    Args:
        lazyweights (LazyWeights): A LazyWeights instance whose weights are in log space

    Returns:
        (Any): Sampled token

    Raises:
        AssertionError: If `lazyweights` is not in log space.
    """
    assert lazyweights.is_log
    draws = fast_sample_logprobs(lazyweights.weights, size=1)
    return lazyweights.decode[draws[0]]