Bases: Potential
Coerce a potential to operate on another vocabulary.
This class allows a potential to be adapted to work with a different set of tokens,
defined by a target vocabulary and a coercion function.
This class inherits all methods from `Potential`.
Each method delegates to the corresponding method of the underlying potential, but first
maps any input token sequences from the target vocabulary to the original potential's vocabulary
using the coercion function.
Formally, if \(f\) is the coercion function, then for any sequence \(x_1, \ldots, x_n\) of tokens from the target vocabulary,
$$
\textsf{Coerced.prefix}(x_1, \ldots, x_n) = \textsf{Coerced.potential.prefix}(f(x_1, \ldots, x_n))
$$
$$
\textsf{Coerced.complete}(x_1, \ldots, x_n) = \textsf{Coerced.potential.complete}(f(x_1, \ldots, x_n))
$$
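To make this delegation concrete, here is a minimal sketch. The potential `byte_potential` and the coercion function `f` are hypothetical names used only for illustration, and the import path is inferred from the source file shown below:

```python
# Minimal sketch, assuming `byte_potential` is some existing Potential whose
# vocabulary consists of individual byte values. Both `byte_potential` and `f`
# are illustrative assumptions, not part of the library.
from genlm.control.potential.coerce import Coerced

def f(xs):
    # Flatten a sequence of bytes-valued tokens into individual byte values,
    # e.g. [b"ab", b"c"] -> [97, 98, 99].
    return [b for x in xs for b in x]

coerced = Coerced(byte_potential, target_vocab=[b"ab", b"c", b"de"], f=f)

# The identities above then read (inside an async context):
#   await coerced.prefix([b"ab", b"c"])   == await byte_potential.prefix([97, 98, 99])
#   await coerced.complete([b"ab", b"c"]) == await byte_potential.complete([97, 98, 99])
```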
Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `potential` | `Potential` | The original potential instance that is being coerced. |
| `f` | `callable` | A function that maps sequences of tokens from the target vocabulary to sequences of tokens from the original potential's vocabulary. |
Note

The coerced potential's vocabulary will by default be pruned to only include tokens that can be mapped to the original potential's vocabulary via the coercion function (i.e. `set(f([x])) <= set(potential.vocab)`). If no such tokens are found, a `ValueError` is raised.

This behavior can be overridden by setting `prune=False`, in which case the coerced potential's vocabulary will include all tokens from the target vocabulary.
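The pruning rule can be illustrated with the same hypothetical `byte_potential` and `f` from the sketch above, assuming the base vocabulary contains the byte values of `"abcde"` but not `"z"`:

```python
# Hypothetical sketch of pruning: a target token x is kept iff
# set(f([x])) <= set(potential.vocab).
target_vocab = [b"ab", b"zz"]

pruned = Coerced(byte_potential, target_vocab, f=f)
# pruned.vocab keeps only b"ab"; b"zz" maps to a byte outside the base vocabulary.

unpruned = Coerced(byte_potential, target_vocab, f=f, prune=False)
# unpruned.vocab keeps both b"ab" and b"zz".
```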
Source code in `genlm/control/potential/coerce.py`

```python
import asyncio
from itertools import chain

from genlm.control.potential.base import Potential


class Coerced(Potential):
    """
    Coerce a potential to operate on another vocabulary.

    This class allows a potential to be adapted to work with a different set of tokens,
    defined by a target vocabulary and a coercion function.

    This class inherits all methods from [`Potential`][genlm.control.potential.base.Potential].
    Each method delegates to the corresponding method of the underlying potential, but first
    maps any input token sequences from the target vocabulary to the original potential's vocabulary
    using the coercion function.

    Formally, if $f$ is the coercion function, then for any sequence $x_1, \\ldots, x_n$ of tokens from the target vocabulary,

    $$
    \\textsf{Coerced.prefix}(x_1, \\ldots, x_n) = \\textsf{Coerced.potential.prefix}(f(x_1, \\ldots, x_n))
    $$

    $$
    \\textsf{Coerced.complete}(x_1, \\ldots, x_n) = \\textsf{Coerced.potential.complete}(f(x_1, \\ldots, x_n))
    $$

    Attributes:
        potential (Potential): The original potential instance that is being coerced.
        f (callable): A function that maps sequences of tokens from the target vocabulary
            to sequences of tokens from the original potential's vocabulary.

    Note:
        The coerced potential's vocabulary will by default be pruned to only include tokens
        that can be mapped to the original potential's vocabulary via the coercion function
        (i.e. `set(f([x])) <= set(potential.vocab)`). If no such tokens are found, a `ValueError` is raised.
        This behavior can be overridden by setting `prune=False`, in which case the coerced
        potential's vocabulary will include all tokens from the target vocabulary.
    """

    def __init__(self, potential, target_vocab, f, prune=True):
        """
        Initialize a Coerced potential.

        Args:
            potential (Potential): The original potential instance that is being coerced.
            target_vocab (list): The target vocabulary that the potential will operate on.
                Each element of `target_vocab` must be hashable.
            f (callable): A function that maps iterables of tokens from the target vocabulary
                to the original potential's vocabulary.
            prune (bool): Whether to prune the coerced potential's vocabulary to only include
                tokens that can be mapped to the original potential's vocabulary.
                If `False`, the coerced potential's vocabulary will include all tokens from the target vocabulary.

        Raises:
            ValueError: If no valid tokens are found in the target vocabulary that can be
                mapped to the original potential's vocabulary.
        """
        self.potential = potential
        self.f = f

        if prune:
            tokens = []
            for target_token in target_vocab:
                base_token = f([target_token])
                if set(base_token) <= set(potential.vocab):
                    tokens.append(target_token)
        else:
            tokens = target_vocab

        if not tokens:
            raise ValueError("No valid tokens found in target vocabulary")

        super().__init__(tokens)

    def _batch_f(self, contexts):
        return [self.f(context) for context in contexts]

    async def complete(self, context):
        return await self.potential.complete(context=self.f(context))

    async def prefix(self, context):
        return await self.potential.prefix(context=self.f(context))

    async def logw_next(self, context):
        Ws = self.alloc_logws()
        ctx = self.f(context)
        ctx_w = await self.potential.prefix(ctx)

        # Weight of stopping here: complete(ctx) relative to prefix(ctx).
        Ws[-1] = await self.potential.complete(ctx) - ctx_w
        # Weight of each next token x: prefix(f(context + [x])) relative to prefix(f(context)).
        exts = [self.f(chain(context, [x])) for x in self.vocab]  # slow!!
        Ws[:-1] = await self.potential.batch_prefix(exts) - ctx_w

        return self.make_lazy_weights(Ws)

    async def batch_complete(self, contexts):
        return await self.potential.batch_complete(contexts=self._batch_f(contexts))

    async def batch_prefix(self, contexts):
        return await self.potential.batch_prefix(contexts=self._batch_f(contexts))

    async def batch_logw_next(self, contexts):
        return await asyncio.gather(*[self.logw_next(context) for context in contexts])

    def __repr__(self):
        return f"{self.__class__.__name__}({self.potential!r})"
```
__init__(potential, target_vocab, f, prune=True)
Initialize a Coerced potential.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `potential` | `Potential` | The original potential instance that is being coerced. | *required* |
| `target_vocab` | `list` | The target vocabulary that the potential will operate on. Each element of `target_vocab` must be hashable. | *required* |
| `f` | `callable` | A function that maps iterables of tokens from the target vocabulary to the original potential's vocabulary. | *required* |
| `prune` | `bool` | Whether to prune the coerced potential's vocabulary to only include tokens that can be mapped to the original potential's vocabulary. If `False`, the coerced potential's vocabulary will include all tokens from the target vocabulary. | `True` |
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If no valid tokens are found in the target vocabulary that can be mapped to the original potential's vocabulary. |
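Continuing the hypothetical sketch from above, the `ValueError` case arises when pruning removes every target token:

```python
# If no target token maps entirely into the base vocabulary, pruning leaves
# an empty vocabulary and construction fails.
try:
    Coerced(byte_potential, target_vocab=[b"zz"], f=f)
except ValueError as err:
    print(err)  # "No valid tokens found in target vocabulary"
```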
Source code in `genlm/control/potential/coerce.py`

```python
    def __init__(self, potential, target_vocab, f, prune=True):
        """
        Initialize a Coerced potential.

        Args:
            potential (Potential): The original potential instance that is being coerced.
            target_vocab (list): The target vocabulary that the potential will operate on.
                Each element of `target_vocab` must be hashable.
            f (callable): A function that maps iterables of tokens from the target vocabulary
                to the original potential's vocabulary.
            prune (bool): Whether to prune the coerced potential's vocabulary to only include
                tokens that can be mapped to the original potential's vocabulary.
                If `False`, the coerced potential's vocabulary will include all tokens from the target vocabulary.

        Raises:
            ValueError: If no valid tokens are found in the target vocabulary that can be
                mapped to the original potential's vocabulary.
        """
        self.potential = potential
        self.f = f

        if prune:
            tokens = []
            for target_token in target_vocab:
                base_token = f([target_token])
                if set(base_token) <= set(potential.vocab):
                    tokens.append(target_token)
        else:
            tokens = target_vocab

        if not tokens:
            raise ValueError("No valid tokens found in target vocabulary")

        super().__init__(tokens)
```