Skip to content

product

Product

Bases: Potential

Combine two potential instances via element-wise multiplication (sum in log space).

This class creates a new potential that is the element-wise product of two potentials:

prefix(xs) = p1.prefix(xs) + p2.prefix(xs)
complete(xs) = p1.complete(xs) + p2.complete(xs)
logw_next(x | xs) = p1.logw_next(x | xs) + p2.logw_next(x | xs)

The new potential's vocabulary is the intersection of the two potentials' vocabularies.

This class inherits all methods from Potential, see there for method documentation.

Attributes:

Name Type Description
p1 Potential

The first potential instance.

p2 Potential

The second potential instance.

token_type str

The type of tokens that this product potential operates on.

vocab list

The vocabulary shared by the two potentials (the intersection of their vocabularies).

Warning

Be careful when taking products of potentials with minimal vocabulary overlap. The resulting potential will only operate on tokens present in both vocabularies.

Source code in genlm/control/potential/product.py
class Product(Potential):
    """
    Combine two potential instances via element-wise multiplication (sum in log space).

    This class creates a new potential that is the element-wise product of two potentials:
    ```
    prefix(xs) = p1.prefix(xs) + p2.prefix(xs)
    complete(xs) = p1.complete(xs) + p2.complete(xs)
    logw_next(x | xs) = p1.logw_next(x | xs) + p2.logw_next(x | xs)
    ```

    The new potential's vocabulary is the intersection of the two potentials' vocabularies.

    This class inherits all methods from [`Potential`][genlm.control.potential.base.Potential],
    see there for method documentation.

    Attributes:
        p1 (Potential): The first potential instance.
        p2 (Potential): The second potential instance.
        token_type (str): The type of tokens that this product potential operates on.
        vocab (list): The common vocabulary shared between the two potentials,
            listed in the order the tokens appear in `p1`'s vocabulary.

    Warning:
        Be careful when taking products of potentials with minimal vocabulary overlap.
        The resulting potential will only operate on tokens present in both vocabularies.
    """

    def __init__(self, p1, p2):
        """Initialize a Product potential.

        Args:
            p1 (Potential): First potential
            p2 (Potential): Second potential

        Raises:
            ValueError: If the two potentials have different token types, or if
                their vocabularies have no token in common.

        Warns:
            RuntimeWarning: If the common vocabulary is less than 10% of either
                potential's vocabulary.
        """
        self.p1 = p1
        self.p2 = p2

        if self.p1.token_type == self.p2.token_type:
            self.token_type = self.p1.token_type
        else:
            raise ValueError(
                "Potentials in product must have the same token type. "
                f"Got {self.p1.token_type} and {self.p2.token_type}."
                + (
                    "\nMaybe you forgot to coerce the potentials to the same token type? See `Coerce`."
                    if (
                        self.p1.token_type.is_iterable_of(self.p2.token_type)
                        or self.p2.token_type.is_iterable_of(self.p1.token_type)
                    )
                    else ""
                )
            )

        # Build the intersection in p1's vocabulary order. Iterating over a set
        # intersection yields an arbitrary order that can change between
        # interpreter runs (str/bytes hashes are randomized), which would make
        # the token order of this potential's vocabulary irreproducible.
        # `dict.fromkeys` drops repeated tokens while preserving first-seen order.
        p2_tokens = set(p2.vocab)
        common_vocab = [
            token for token in dict.fromkeys(p1.vocab) if token in p2_tokens
        ]
        if not common_vocab:
            raise ValueError("Potentials in product must share a common vocabulary")

        # Check for small vocabulary overlap
        threshold = 0.1
        for potential, name in [(p1, "p1"), (p2, "p2")]:
            overlap_ratio = len(common_vocab) / len(potential.vocab)
            if overlap_ratio < threshold:
                warnings.warn(
                    f"Common vocabulary ({len(common_vocab)} tokens) is less than {threshold * 100}% "
                    f"of {name}'s ({potential!r}) vocabulary ({len(potential.vocab)} tokens). "
                    "This Product potential only operates on this relatively small subset of tokens.",
                    RuntimeWarning,
                )

        super().__init__(common_vocab, token_type=self.token_type)

        # For fast products of weights: for every token of this potential's
        # `vocab_eos`, the index of that token in p1's / p2's `lookup` table.
        self.v1_idxs = [p1.lookup[token] for token in self.vocab_eos]
        self.v2_idxs = [p2.lookup[token] for token in self.vocab_eos]

    async def prefix(self, context):
        """Return `p1.prefix(context) + p2.prefix(context)`.

        `p2` is not evaluated when `p1` already returns negative infinity.
        """
        w1 = await self.p1.prefix(context)
        if w1 == float("-inf"):
            return float("-inf")
        w2 = await self.p2.prefix(context)
        return w1 + w2

    async def complete(self, context):
        """Return `p1.complete(context) + p2.complete(context)`.

        `p2` is not evaluated when `p1` already returns negative infinity.
        """
        w1 = await self.p1.complete(context)
        if w1 == float("-inf"):
            return float("-inf")
        w2 = await self.p2.complete(context)
        return w1 + w2

    async def batch_complete(self, contexts):
        """Await both potentials' `batch_complete` concurrently and return their sum."""
        W1, W2 = await asyncio.gather(
            self.p1.batch_complete(contexts), self.p2.batch_complete(contexts)
        )
        # NOTE(review): presumably array-like results, so `+` is element-wise —
        # confirm against `Potential.batch_complete`.
        return W1 + W2

    async def batch_prefix(self, contexts):
        """Await both potentials' `batch_prefix` concurrently and return their sum."""
        W1, W2 = await asyncio.gather(
            self.p1.batch_prefix(contexts), self.p2.batch_prefix(contexts)
        )
        # NOTE(review): presumably array-like results, so `+` is element-wise —
        # confirm against `Potential.batch_prefix`.
        return W1 + W2

    async def logw_next(self, context):
        """Return lazy weights built from the sum of both potentials' next-token weights.

        Each potential's weights are first re-indexed with `v1_idxs` / `v2_idxs`
        so that both are aligned to this potential's `vocab_eos` before adding.
        """
        W1, W2 = await asyncio.gather(
            self.p1.logw_next(context), self.p2.logw_next(context)
        )
        return self.make_lazy_weights(
            W1.weights[self.v1_idxs] + W2.weights[self.v2_idxs]
        )

    async def batch_logw_next(self, contexts):
        """Batched form of `logw_next`: return one lazy-weights object per context."""
        Ws1, Ws2 = await asyncio.gather(
            self.p1.batch_logw_next(contexts), self.p2.batch_logw_next(contexts)
        )
        return [
            self.make_lazy_weights(
                Ws1[n].weights[self.v1_idxs] + Ws2[n].weights[self.v2_idxs]
            )
            for n in range(len(contexts))
        ]

    def spawn(self, p1_opts=None, p2_opts=None):
        """Return a new Product built from `p1.spawn(...)` and `p2.spawn(...)`.

        Args:
            p1_opts (dict, optional): Keyword arguments forwarded to `p1.spawn`.
                `None` means no keyword arguments.
            p2_opts (dict, optional): Keyword arguments forwarded to `p2.spawn`.
                `None` means no keyword arguments.
        """
        return Product(
            self.p1.spawn(**(p1_opts or {})),
            self.p2.spawn(**(p2_opts or {})),
        )

    def __repr__(self):
        return f"Product({self.p1!r}, {self.p2!r})"

__init__(p1, p2)

Initialize a Product potential.

Parameters:

Name Type Description Default
p1 Potential

First potential

required
p2 Potential

Second potential

required
Source code in genlm/control/potential/product.py
def __init__(self, p1, p2):
    """Initialize a Product potential.

    Args:
        p1 (Potential): First potential
        p2 (Potential): Second potential

    Raises:
        ValueError: If the two potentials have different token types, or if
            their vocabularies have no token in common.

    Warns:
        RuntimeWarning: If the common vocabulary is less than 10% of either
            potential's vocabulary.
    """
    self.p1 = p1
    self.p2 = p2

    if self.p1.token_type == self.p2.token_type:
        self.token_type = self.p1.token_type
    else:
        raise ValueError(
            "Potentials in product must have the same token type. "
            f"Got {self.p1.token_type} and {self.p2.token_type}."
            + (
                "\nMaybe you forgot to coerce the potentials to the same token type? See `Coerce`."
                if (
                    self.p1.token_type.is_iterable_of(self.p2.token_type)
                    or self.p2.token_type.is_iterable_of(self.p1.token_type)
                )
                else ""
            )
        )

    # Build the intersection in p1's vocabulary order. Iterating over a set
    # intersection yields an arbitrary order that can change between
    # interpreter runs (str/bytes hashes are randomized), which would make
    # the token order of this potential's vocabulary irreproducible.
    # `dict.fromkeys` drops repeated tokens while preserving first-seen order.
    p2_tokens = set(p2.vocab)
    common_vocab = [
        token for token in dict.fromkeys(p1.vocab) if token in p2_tokens
    ]
    if not common_vocab:
        raise ValueError("Potentials in product must share a common vocabulary")

    # Check for small vocabulary overlap
    threshold = 0.1
    for potential, name in [(p1, "p1"), (p2, "p2")]:
        overlap_ratio = len(common_vocab) / len(potential.vocab)
        if overlap_ratio < threshold:
            warnings.warn(
                f"Common vocabulary ({len(common_vocab)} tokens) is less than {threshold * 100}% "
                f"of {name}'s ({potential!r}) vocabulary ({len(potential.vocab)} tokens). "
                "This Product potential only operates on this relatively small subset of tokens.",
                RuntimeWarning,
            )

    super().__init__(common_vocab, token_type=self.token_type)

    # For fast products of weights: for every token of this potential's
    # `vocab_eos`, the index of that token in p1's / p2's `lookup` table.
    self.v1_idxs = [p1.lookup[token] for token in self.vocab_eos]
    self.v2_idxs = [p2.lookup[token] for token in self.vocab_eos]