Skip to content

wfsa

WFSA

Bases: WFSA

Weighted finite-state automata whose weights form a field (e.g., real-valued).

Source code in genlm/grammar/wfsa/field_wfsa.py
class WFSA(base.WFSA):
    """
    Weighted finite-state automata whose weights form a field (e.g., real-valued).

    Equality, hashing, counterexample search and minimization are all delegated
    to the matrix-based `Simple` representation built by the `simple` property.
    """

    def __init__(self, R=Float):
        """Initialize the automaton over the weight domain `R` (default `Float`)."""
        super().__init__(R=R)

    def __hash__(self):
        """Hash via the `Simple` representation, matching what `__eq__` compares."""
        return hash(self.simple)

    def threshold(self, threshold):
        """
        Return a new automaton without the small-magnitude weights.

        Initial weights, arcs, and final weights are copied into a fresh
        instance of the same class (over the same `R`) only when
        `abs(w) >= threshold`; everything else is dropped.
        """
        m = self.__class__(self.R)
        for q, w in self.I:
            if abs(w) >= threshold:
                m.add_I(q, w)
        for i, a, j, w in self.arcs():
            if abs(w) >= threshold:
                m.add_arc(i, a, j, w)
        for q, w in self.F:
            if abs(w) >= threshold:
                m.add_F(q, w)
        return m

    def graphviz(
        self,
        fmt=lambda x: f"{round(x, 3):g}" if isinstance(x, (float, int)) else str(x),
        **kwargs,
    ):  # pylint: disable=arguments-differ
        """
        Render via the parent class's `graphviz`, using `fmt` to format weights.

        The default `fmt` rounds `float`/`int` values to 3 decimals and prints
        them with the `g` format; any other value goes through `str`.
        """
        return super().graphviz(fmt=fmt, **kwargs)

    @cached_property
    def simple(self):
        """
        Build (once, then cache) the `Simple` form of this automaton.

        The automaton is first passed through `epsremove` and `renumber`; its
        initial, arc, and final weights are then accumulated into a start
        vector, one `(dim, dim)` matrix per alphabet symbol, and a stop vector.
        """
        # Work on the normalized automaton under its own name rather than
        # rebinding `self`.
        # NOTE(review): states of `renumber` are used directly as array
        # indices below, so they are presumably integers in range(dim) — confirm.
        m = self.epsremove.renumber

        S = m.dim
        start = np.full(S, m.R.zero)
        arcs = {a: np.full((S, S), m.R.zero) for a in m.alphabet}
        stop = np.full(S, m.R.zero)

        # `+=` accumulates, so repeated entries for the same state/arc are summed.
        for i, w in m.I:
            start[i] += w
        for i, a, j, w in m.arcs():
            arcs[a][i, j] += w
        for i, w in m.F:
            stop[i] += w

        # Sanity check: no per-symbol matrix may be keyed by epsilon.
        assert EPSILON not in arcs

        return Simple(start, arcs, stop)

    def __eq__(self, other):
        """
        Compare two automata through their `Simple` representations.

        Returns `NotImplemented` for operands whose class has no `simple`
        attribute, so that e.g. `wfsa == None` evaluates to `False` instead of
        raising `AttributeError`.
        """
        # Inspect the class (not the instance) so the check never triggers the
        # cached `simple` computation or hides an error raised inside it.
        if not hasattr(type(other), "simple"):
            return NotImplemented
        return self.simple == other.simple

    def counterexample(self, other):
        """Delegate to `Simple.counterexample` on the two `Simple` representations."""
        return self.simple.counterexample(other.simple)

    @cached_property
    def min(self):
        """Minimize via the `Simple` representation and convert back with `to_wfsa`."""
        return self.simple.min.to_wfsa()

    #    @cached_property
    #    def epsremove(self):
    #        return self.simple.to_wfsa()

    def multiplicity(self, m):
        """
        Return the product of a single epsilon arc of weight `m` with this automaton.

        The lifted automaton is built over `self.R` so both operands of the
        product share the same weight domain.
        """
        return WFSA.lift(EPSILON, m, R=self.R) * self

    @classmethod
    def lift(cls, x, w, R=None):
        """
        Build a two-state automaton with a single arc labeled `x` of weight `w`.

        State 0 is initial and state 1 is final, both with weight `R.one`.
        `R` defaults to `Float` when not given.
        """
        if R is None:
            R = Float
        m = cls(R=R)
        m.add_I(0, R.one)
        m.add_arc(0, x, 1, w)
        m.add_F(1, R.one)
        return m

threshold(threshold)

Drop initial weights, arcs, and final weights whose absolute value is below the given threshold.

Source code in genlm/grammar/wfsa/field_wfsa.py
def threshold(self, threshold):
    """
    Return a pruned copy of this automaton.

    A fresh instance of the same class (over the same `R`) receives every
    initial weight, arc, and final weight whose absolute value is at least
    `threshold`; entries of smaller magnitude are left out.
    """

    def is_kept(weight):
        # Same comparison as always: keep only when abs(weight) >= threshold.
        return abs(weight) >= threshold

    pruned = self.__class__(self.R)

    for state, weight in self.I:
        if is_kept(weight):
            pruned.add_I(state, weight)

    for src, label, dst, weight in self.arcs():
        if is_kept(weight):
            pruned.add_arc(src, label, dst, weight)

    for state, weight in self.F:
        if is_kept(weight):
            pruned.add_F(state, weight)

    return pruned