Skip to content

grammar

BoolCFGLM

Bases: LM

Language model interface for Boolean-weighted CFGs.

Uses Earley's algorithm or CKY for inference. The grammar is converted to use Boolean weights if needed, where positive weights become True and zero/negative weights become False.

Parameters:

Name Type Description Default
cfg CFG

The context-free grammar to use

required
alg str

Parsing algorithm to use - either 'earley' or 'cky'

'earley'

Raises:

Type Description
ValueError

If alg is not 'earley' or 'cky'

Source code in genlm/grammar/cfglm.py
class BoolCFGLM(LM):
    """Language model interface for Boolean-weighted CFGs.

    Uses Earley's algorithm or CKY for inference. The grammar is converted to use
    Boolean weights if needed, where positive weights become True and zero/negative
    weights become False.

    Args:
        cfg (CFG): The context-free grammar to use
        alg (str): Parsing algorithm to use - either 'earley' or 'cky'

    Raises:
        ValueError: If alg is not 'earley' or 'cky'
    """

    def __init__(self, cfg, alg="earley"):
        """Initialize a BoolCFGLM.

        Args:
            cfg (CFG): The context-free grammar to use as the language model
            alg (str): Parsing algorithm to use - either 'earley' or 'cky'

        Raises:
            ValueError: If alg is not 'earley' or 'cky'
        """
        # Guarantee the grammar can emit the end-of-sequence marker.
        if EOS not in cfg.V:
            cfg = add_EOS(cfg, eos=EOS)
        # Coerce to Boolean weights: positive -> True, zero/negative -> False.
        if cfg.R != Boolean:
            cfg = cfg.map_values(lambda x: Boolean(x > 0), Boolean)
        if alg == "earley":
            from genlm.grammar.parse.earley import Earley

            # NOTE(review): Earley is given `cfg.prefix_grammar`, presumably so
            # that incomplete contexts (string prefixes) can be scored — confirm.
            self.model = Earley(cfg.prefix_grammar)
        elif alg == "cky":
            from genlm.grammar.parse.cky import CKYLM

            self.model = CKYLM(cfg)
        else:
            raise ValueError(f"unrecognized option {alg}")
        super().__init__(eos=EOS, V=cfg.V)

    def p_next(self, context):
        """Compute the admissible next tokens given a context.

        Each token that may follow the context is assigned weight 1, so the
        result is an indicator chart over allowed tokens, not a normalized
        probability distribution.

        Args:
            context (sequence): The conditioning context

        Returns:
            (Float.chart): The next token weights (1 for each admissible token)

        Raises:
            AssertionError: If context contains out-of-vocabulary tokens
        """
        assert set(context) <= self.V, f"OOVs detected: {set(context) - self.V}"
        p = self.model.next_token_weights(self.model.chart(context)).trim()
        return Float.chart({w: 1 for w in p})

    def __call__(self, context):
        """Check if a context is possible under this grammar.

        Args:
            context (sequence): The context to check

        Returns:
            (float): 1.0 if the context has non-zero weight, 0.0 otherwise
        """
        return float(super().__call__(context) > 0)

    def clear_cache(self):
        """Clear any cached computations held by the underlying parsing model."""
        self.model.clear_cache()

    @classmethod
    def from_string(cls, x, semiring=Boolean, **kwargs):
        """Create a BoolCFGLM from a string representation of a grammar.

        Args:
            x (str): The grammar string
            semiring: The semiring for weights (default: Boolean)
            **kwargs: Additional arguments passed to __init__

        Returns:
            (BoolCFGLM): A new language model
        """
        return cls(CFG.from_string(x, semiring), **kwargs)

__call__(context)

Check if a context is possible under this grammar.

Parameters:

Name Type Description Default
context sequence

The context to check

required

Returns:

Type Description
float

1.0 if the context has non-zero weight, 0.0 otherwise

Source code in genlm/grammar/cfglm.py
def __call__(self, context):
    """Check if a context is possible under this grammar.

    Args:
        context (sequence): The context to check

    Returns:
        (float): 1.0 if the context has non-zero weight, 0.0 otherwise
    """
    return float(super().__call__(context) > 0)

__init__(cfg, alg='earley')

Initialize a BoolCFGLM.

Parameters:

Name Type Description Default
cfg CFG

The context-free grammar to use as the language model

required
alg str

Parsing algorithm to use - either 'earley' or 'cky'

'earley'

Raises:

Type Description
ValueError

If alg is not 'earley' or 'cky'

Source code in genlm/grammar/cfglm.py
def __init__(self, cfg, alg="earley"):
    """Initialize a BoolCFGLM.

    Args:
        cfg (CFG): The context-free grammar to use as the language model
        alg (str): Parsing algorithm to use - either 'earley' or 'cky'

    Raises:
        ValueError: If alg is not 'earley' or 'cky'
    """
    if EOS not in cfg.V:
        cfg = add_EOS(cfg, eos=EOS)
    if cfg.R != Boolean:
        cfg = cfg.map_values(lambda x: Boolean(x > 0), Boolean)
    if alg == "earley":
        from genlm.grammar.parse.earley import Earley

        self.model = Earley(cfg.prefix_grammar)
    elif alg == "cky":
        from genlm.grammar.parse.cky import CKYLM

        self.model = CKYLM(cfg)
    else:
        raise ValueError(f"unrecognized option {alg}")
    super().__init__(eos=EOS, V=cfg.V)

clear_cache()

Clear any cached computations.

Source code in genlm/grammar/cfglm.py
def clear_cache(self):
    """Discard any memoized state held by the underlying parsing model."""
    self.model.clear_cache()

from_string(x, semiring=Boolean, **kwargs) classmethod

Create a BoolCFGLM from a string representation of a grammar.

Parameters:

Name Type Description Default
x str

The grammar string

required
semiring

The semiring for weights (default: Boolean)

Boolean
**kwargs

Additional arguments passed to init

{}

Returns:

Type Description
BoolCFGLM

A new language model

Source code in genlm/grammar/cfglm.py
@classmethod
def from_string(cls, x, semiring=Boolean, **kwargs):
    """Build a BoolCFGLM by parsing a grammar from its string form.

    Args:
        x (str): The grammar string
        semiring: The semiring used to interpret rule weights (default: Boolean)
        **kwargs: Forwarded to the class constructor

    Returns:
        (BoolCFGLM): A new language model
    """
    grammar = CFG.from_string(x, semiring)
    return cls(grammar, **kwargs)

p_next(context)

Compute next token probabilities given a context.

Parameters:

Name Type Description Default
context sequence

The conditioning context

required

Returns:

Type Description
chart

The next token weights

Raises:

Type Description
AssertionError

If context contains out-of-vocabulary tokens

Source code in genlm/grammar/cfglm.py
def p_next(self, context):
    """Return the chart of admissible next tokens for a context.

    Every token that can follow the context receives weight 1; the result
    is an indicator chart rather than a normalized distribution.

    Args:
        context (sequence): The conditioning context

    Returns:
        (Float.chart): The next token weights

    Raises:
        AssertionError: If context contains out-of-vocabulary tokens
    """
    assert set(context) <= self.V, f"OOVs detected: {set(context) - self.V}"
    chart = self.model.chart(context)
    weights = self.model.next_token_weights(chart).trim()
    return Float.chart({token: 1 for token in weights})

CFG

Weighted Context-free Grammar

A weighted context-free grammar consists of:

  • R: A semiring that defines the weights

  • S: A start symbol (nonterminal)

  • V: A set of terminal symbols (vocabulary)

  • N: A set of nonterminal symbols

  • rules: A list of weighted production rules

Each rule has the form: w: X -> Y1 Y2 ... Yn where:

  • w is a weight from the semiring R

  • X is a nonterminal symbol

  • Y1...Yn are terminal or nonterminal symbols

Source code in genlm/grammar/cfg.py
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
class CFG:
    """
    Weighted Context-free Grammar

    A weighted context-free grammar consists of:\n
    - `R`: A semiring that defines the weights\n
    - `S`: A start symbol (nonterminal)\n
    - `V`: A set of terminal symbols (vocabulary)\n
    - `N`: A set of nonterminal symbols\n
    - `rules`: A list of weighted production rules\n

    Each rule has the form: w: X -> Y1 Y2 ... Yn where:\n
    - w is a weight from the semiring R\n
    - X is a nonterminal symbol\n
    - Y1...Yn are terminal or nonterminal symbols\n
    """

    def __init__(self, R, S, V):
        """
        Initialize a weighted CFG.

        Args:
            R: The semiring for rule weights
            S: The start symbol (nonterminal)
            V: The set of terminal symbols (vocabulary)
        """
        self.R = R  # semiring
        self.V = V  # alphabet
        self.N = {S}  # nonterminals
        self.S = S  # unique start symbol
        self.rules = []  # rules
        # Memoized results of trim(): index 0 holds the full trim,
        # index 1 the bottom-up-only (cotrim) variant.
        self._trim_cache = [None, None]

    def __repr__(self):
        """Render the grammar as a brace-delimited, one-rule-per-line listing."""
        listing = "\n".join(f"  {r}" for r in self)
        return "Grammar {\n%s\n}" % listing

    def _repr_html_(self):
        """Render the grammar inside a styled <pre> element for Jupyter notebooks."""
        css = "width: fit-content; text-align: left; border: thin solid black; padding: 0.5em;"
        return f'<pre style="{css}">{self}</pre>'

    @classmethod
    def from_string(
        cls,
        string,
        semiring,
        comment="#",
        start="S",
        is_terminal=lambda x: not x[0].isupper(),
    ):
        """
        Create a CFG from a string representation.

        Each non-blank, non-comment line must have the form
        ``weight: LHS -> Y1 Y2 ...`` (``→`` is accepted as a synonym for ``->``;
        the right-hand side may be empty).

        Args:
            string: The grammar rules as a string
            semiring: The semiring for rule weights
            comment: Comment character to ignore lines (default: '#')
            start: Start symbol (default: 'S')
            is_terminal: Function to identify terminal symbols (default: lowercase first letter)

        Returns:
            A new CFG instance

        Raises:
            ValueError: If a line cannot be parsed as a weighted rule
        """
        V = set()
        cfg = cls(R=semiring, S=start, V=V)
        string = string.replace("->", "→")  # synonym for the arrow
        for line in string.split("\n"):
            line = line.strip()
            if not line or line.startswith(comment):
                continue
            try:
                # Expect exactly one "weight : head → body" match; the
                # single-element unpacking raises ValueError otherwise.
                [(w, lhs, rhs)] = re.findall(r"(.*):\s*(\S+)\s*→\s*(.*)$", line)
                lhs = lhs.strip()
                rhs = rhs.strip().split()
                for x in rhs:
                    if is_terminal(x):
                        V.add(x)
                cfg.add(semiring.from_string(w), lhs, *rhs)
            except ValueError:
                raise ValueError(f"bad input line:\n{line}")  # pylint: disable=W0707
        return cfg

    def __getitem__(self, root):
        """
        Return a grammar that denotes the sublanguage of the nonterminal `root`.

        Args:
            root: The nonterminal to use as the new start symbol

        Returns:
            A new CFG with root as the start symbol
        """
        sub = self.spawn(S=root)
        for rule in self:
            sub.add(rule.w, rule.head, *rule.body)
        return sub

    def __len__(self):
        """The number of production rules in the grammar."""
        return len(self.rules)

    def __call__(self, xs):
        """
        Compute the total weight of the sequence xs.

        Args:
            xs: A sequence of terminal symbols

        Returns:
            The total weight of all derivations of xs
        """
        # Rebind `self` to the CNF-converted grammar; conversion may introduce
        # a fresh start symbol, so the chart must be indexed with the CNF
        # grammar's start symbol rather than the original one.
        self = self.cnf  # need to do this here because the start symbol might change
        return self._parse_chart(xs)[0, self.S, len(xs)]

    def _parse_chart(self, xs):
        """
        Implements CKY algorithm for evaluating the total weight of the xs sequence.

        The chart is indexed by (start, symbol, end) spans over positions
        0..len(xs).

        Args:
            xs: A sequence of terminal symbols

        Returns:
            A chart containing the weights of all subderivations
        """
        (nullary, terminal, binary) = self._cnf  # will convert to CNF
        N = len(xs)
        # nullary rule: the start symbol may derive the empty span anywhere
        c = self.R.chart()
        for i in range(N + 1):
            c[i, self.S, i] += nullary
        # preterminal rules: seed width-1 spans from the input tokens
        for i in range(N):
            for r in terminal[xs[i]]:
                c[i, r.head, i + 1] += r.w
        # binary rules: combine adjacent spans, widest last
        for span in range(2, N + 1):
            for i in range(N - span + 1):
                k = i + span
                for j in range(i + 1, k):
                    for r in binary:
                        X, [Y, Z] = r.head, r.body
                        c[i, X, k] += r.w * c[i, Y, j] * c[j, Z, k]
        return c

    def language(self, depth):
        """
        Enumerate strings generated by this cfg by derivations up to the given depth.

        Args:
            depth: Maximum derivation depth to consider

        Returns:
            A chart containing the weighted language up to the given depth
        """
        chart = self.R.chart()
        for deriv in self.derivations(self.S, depth):
            chart[deriv.Yield()] += deriv.weight()
        return chart

    @cached_property
    def rhs(self):
        """
        Map from each nonterminal to the list of rules with it as their left-hand side.

        Computed once and cached on the instance.

        Returns:
            A dict mapping nonterminals to lists of rules
        """
        index = defaultdict(list)
        for rule in self:
            index[rule.head].append(rule)
        return index

    def is_terminal(self, x):
        """True iff x belongs to the terminal vocabulary V."""
        return x in self.V

    def is_nonterminal(self, X):
        """True iff X is not a terminal symbol (i.e., X is a nonterminal)."""
        return not self.is_terminal(X)

    def __iter__(self):
        """Yield the grammar's rules in insertion order."""
        yield from self.rules

    @property
    def size(self):
        """Total grammar size: each rule contributes 1 plus its body length."""
        return sum(len(r.body) + 1 for r in self)

    @property
    def num_rules(self):
        """The count of production rules in this grammar."""
        return len(self.rules)

    @property
    def expected_length(self):
        """
        Compute the expected length of a string using the Expectation semiring.

        Returns:
            The expected length of strings generated by this grammar

        Raises:
            AssertionError: If grammar is not over the Float semiring
        """
        assert self.R == Float, (
            "This method only supports grammars over the Float semiring"
        )
        # Lift each rule weight into the Expectation semiring as a pair of
        # (weight, weight * number-of-terminals-in-body); the second component
        # of the treesum then accumulates the expected string length.
        new_cfg = self.__class__(R=Expectation, S=self.S, V=self.V)
        for r in self:
            new_cfg.add(
                Expectation(r.w, r.w * sum(self.is_terminal(y) for y in r.body)),
                r.head,
                *r.body,
            )
        return new_cfg.treesum().score[1]

    def spawn(self, *, R=None, S=None, V=None):
        """
        Create an empty grammar with the same R, S, and V (unless overridden).

        Args:
            R: Optional new semiring
            S: Optional new start symbol
            V: Optional new vocabulary

        Returns:
            A new empty CFG with the specified parameters
        """
        semiring = self.R if R is None else R
        start = self.S if S is None else S
        # Copy our vocabulary when no replacement is given, so the spawned
        # grammar cannot mutate ours.
        vocab = set(self.V) if V is None else V
        return self.__class__(R=semiring, S=start, V=vocab)

    def add(self, w, head, *body):
        """
        Add a rule of the form w: head -> body1, body2, ... body_k.

        Args:
            w: The rule weight
            head: The left-hand side nonterminal
            *body: The right-hand side symbols

        Returns:
            The added rule, or None if the weight is the semiring zero
        """
        if w == self.R.zero:
            # Zero-weight rules contribute nothing; don't store them.
            return
        self.N.add(head)
        rule = Rule(w, head, body)
        self.rules.append(rule)
        return rule

    def renumber(self):
        """
        Rename nonterminals to integers.

        The integers are offset past any integer-valued terminals so the two
        namespaces cannot collide.

        Returns:
            A new CFG with integer nonterminals
        """
        ints = Integerizer()
        offset = max((x for x in self.V if isinstance(x, int)), default=0) + 1
        return self.rename(lambda x: ints(x) + offset)

    def rename(self, f):
        """
        Return a new grammar that is the result of applying f to each nonterminal.

        Terminal symbols are left untouched.

        Args:
            f: Function to rename nonterminals

        Returns:
            A new CFG with renamed nonterminals
        """
        renamed = self.spawn(S=f(self.S))
        for r in self:
            body = (y if self.is_terminal(y) else f(y) for y in r.body)
            renamed.add(r.w, f(r.head), *body)
        return renamed

    def map_values(self, f, R):
        """
        Return a new grammar that is the result of applying f: self.R -> R to each rule's weight.

        Args:
            f: Function to map weights
            R: New semiring for weights

        Returns:
            A new CFG over semiring R with mapped weights
        """
        mapped = self.spawn(R=R)
        for rule in self:
            mapped.add(f(rule.w), rule.head, *rule.body)
        return mapped

    def assert_equal(self, other, verbose=False, throw=True):
        """
        Assertion for the equality of self and other modulo rule reordering.

        Args:
            other: The grammar to compare against; a string is parsed with
                self's semiring first
            verbose: If True, print differences
            throw: If True, raise AssertionError on inequality

        Raises:
            AssertionError: If grammars are not equal and throw=True
        """
        assert verbose or throw
        if isinstance(other, str):
            other = self.__class__.from_string(other, self.R)
        if verbose:
            # TODO: need to check the weights in the print out; we do it in the assertion
            S = set(self.rules)
            G = set(other.rules)
            # Print only rules present on one side, marked by membership.
            for r in sorted(S | G, key=str):
                if r in S and r in G:
                    continue
                # if r in S and r not in G: continue
                # if r not in S and r in G: continue
                print(
                    colors.mark(r in S),
                    # colors.mark(r in S and r in G),
                    colors.mark(r in G),
                    r,
                )
        # Counter comparison accounts for duplicated rules, unlike set equality.
        assert not throw or Counter(self.rules) == Counter(other.rules), (
            f"\n\nhave=\n{str(self)}\nwant=\n{str(other)}"
        )

    def treesum(self, **kwargs):
        """
        Total weight of the start symbol.

        Returns:
            The total weight of all derivations from the start symbol
        """
        weights = self.agenda(**kwargs)
        return weights[self.S]

    def trim(self, bottomup_only=False):
        """
        Return an equivalent grammar with no dead or useless nonterminals or rules.

        Runs a bottom-up fixpoint to find generating symbols, and (unless
        bottomup_only) a top-down fixpoint to keep only those also reachable
        from the start symbol. Results are memoized per flag value.

        Args:
            bottomup_only: If True, only remove non-generating nonterminals

        Returns:
            A new trimmed CFG
        """
        if self._trim_cache[bottomup_only] is not None:
            return self._trim_cache[bottomup_only]

        # C = generating symbols; seed with terminals and nullary-rule heads.
        C = set(self.V)
        C.update(e.head for e in self.rules if len(e.body) == 0)

        incoming = defaultdict(list)
        outgoing = defaultdict(list)
        for e in self:
            incoming[e.head].append(e)
            for b in e.body:
                outgoing[b].append(e)

        # Bottom-up pass: a head is generating once its whole body is.
        agenda = set(C)
        while agenda:
            x = agenda.pop()
            for e in outgoing[x]:
                if all((b in C) for b in e.body):
                    if e.head not in C:
                        C.add(e.head)
                        agenda.add(e.head)

        if bottomup_only:
            val = self._trim(C)
            self._trim_cache[bottomup_only] = val
            # The trimmed grammar is its own fixpoint, so cache it on itself too.
            val._trim_cache[bottomup_only] = val
            return val

        # Top-down pass: T = generating symbols reachable from the start symbol.
        T = {self.S}
        agenda.update(T)
        while agenda:
            x = agenda.pop()
            for e in incoming[x]:
                # assert e.head in T
                for b in e.body:
                    if b not in T and b in C:
                        T.add(b)
                        agenda.add(b)

        val = self._trim(T)
        self._trim_cache[bottomup_only] = val
        val._trim_cache[bottomup_only] = val
        return val

    def cotrim(self):
        """
        Trim the grammar so that all nonterminals are generating.

        Equivalent to ``trim(bottomup_only=True)``.

        Returns:
            A new CFG with only generating nonterminals
        """
        return self.trim(bottomup_only=True)

    def _trim(self, symbols):
        """
        Helper for trim(): keep only rules fully contained in `symbols`.

        Args:
            symbols: Set of symbols to keep

        Returns:
            A new CFG whose rules have head and body entirely inside `symbols`
        """
        kept = self.spawn()
        for rule in self:
            if rule.head in symbols and rule.w != self.R.zero and set(rule.body) <= symbols:
                kept.add(rule.w, rule.head, *rule.body)
        return kept

    # ___________________________________________________________________________
    # Derivation enumeration

    def derivations(self, X, H):
        """
        Enumerate derivations of symbol X with height <= H.

        Args:
            X: The symbol to derive from; if None, the start symbol is used
            H: Maximum derivation height

        Yields:
            Derivation objects representing derivation trees; terminal symbols
            are yielded as-is
        """
        if X is None:
            X = self.S
        if self.is_terminal(X):
            yield X
        elif H <= 0:
            # Height budget exhausted before reaching a terminal: no derivation.
            return
        else:
            for r in self.rhs[X]:
                for ys in self._derivations_list(r.body, H - 1):
                    yield Derivation(r, X, *ys)

    def _derivations_list(self, Xs, H):
        """
        Helper for derivations: expand a sequence of symbols up to height H.

        Args:
            Xs: List of symbols to derive
            H: Maximum derivation height

        Yields:
            Tuples pairing a derivation of the first symbol with derivations
            of the rest
        """
        if not Xs:
            yield ()
            return
        first, rest = Xs[0], Xs[1:]
        for head in self.derivations(first, H):
            for tail in self._derivations_list(rest, H):
                yield (head, *tail)

    # ___________________________________________________________________________
    # Transformations

    def _unary_graph(self):
        """
        Build the weighted graph of unary nonterminal-to-nonterminal rules.

        Returns:
            A WeightedGraph with an edge head -> body[0] for each unary rule
        """
        graph = WeightedGraph(self.R)
        for rule in self:
            if len(rule.body) == 1 and self.is_nonterminal(rule.body[0]):
                graph[rule.head, rule.body[0]] += rule.w
        graph.N |= self.N
        return graph

    def _unary_graph_transpose(self):
        """
        Build the transposed weighted graph of unary nonterminal rules.

        Returns:
            A WeightedGraph with an edge body[0] -> head for each unary rule
        """
        graph = WeightedGraph(self.R)
        for rule in self:
            if len(rule.body) == 1 and self.is_nonterminal(rule.body[0]):
                graph[rule.body[0], rule.head] += rule.w
        graph.N |= self.N
        return graph

    def unaryremove(self):
        """
        Return an equivalent grammar with no unary rules.

        Unary nonterminal chains are folded into the closure of the unary-rule
        graph, which reweights every remaining (non-unary) rule.

        Returns:
            A new CFG without unary rules
        """
        closure = self._unary_graph().closure_scc_based()

        out = self.spawn()
        for r in self:
            if len(r.body) == 1 and self.is_nonterminal(r.body[0]):
                continue  # unary rules are absorbed into the closure weights
            for Y in self.N:
                out.add(closure[Y, r.head] * r.w, Y, *r.body)

        return out

    def has_unary_cycle(self):
        """
        Check if the grammar has unary cycles.

        Returns:
            True if some unary rule stays within one SCC of the unary graph
        """
        bucket = self._unary_graph().buckets
        return any(
            len(r.body) == 1 and bucket.get(r.head) == bucket.get(r.body[0])
            for r in self
        )

    def unarycycleremove(self, trim=True):
        """
        Return an equivalent grammar with no unary cycles.

        Nonterminals inside cyclic SCCs of the unary-rule graph are split into
        an `(X, "bot")` copy, and the total weight of each cycle is summed in
        closed form per SCC.

        Args:
            trim: If True, trim the resulting grammar

        Returns:
            A new CFG without unary cycles
        """

        def bot(x):
            # `(x, "bot")` is the cycle-broken copy of x; symbols in acyclic
            # SCCs keep their original name.
            return x if x in acyclic else (x, "bot")

        G = self._unary_graph()

        new = self.spawn(S=self.S)

        bucket = G.buckets

        # Trivially acyclic SCCs: singletons without a self-loop.
        acyclic = set()
        for nodes, _ in G.Blocks:
            if len(nodes) == 1:
                [X] = nodes
                if G[X, X] == self.R.zero:
                    acyclic.add(X)

        # run Lehmann's on each cylical SCC
        for nodes, W in G.Blocks:
            if len(nodes) == 1:
                [X] = nodes
                if X in acyclic:
                    continue

            for X1, X2 in W:
                new.add(W[X1, X2], X1, bot(X2))

        for r in self:
            # Drop within-SCC unary rules; their weight lives in the closure above.
            if len(r.body) == 1 and bucket.get(r.body[0]) == bucket[r.head]:
                continue
            new.add(r.w, bot(r.head), *r.body)

        # TODO: figure out how to ensure that the new grammar is trimmed by
        # construction (assuming the input grammar was trim).
        if trim:
            new = new.trim()

        return new

    def nullaryremove(self, binarize=True, trim=True, **kwargs):
        """
        Return an equivalent grammar with no nullary rules except for one at the start symbol.

        Args:
            binarize: If True, binarize the grammar first
            trim: If True, trim the resulting grammar
            **kwargs: Additional arguments passed to _push_null_weights

        Returns:
            A new CFG without nullary rules (except at start)
        """
        # A really wide rule can take a very long time because of the power set
        # in this rule so it is really important to binarize.
        if binarize:
            self = self.binarize()  # pragma: no cover
        # _push_null_weights assumes the start symbol never appears on a RHS.
        self = self.separate_start()
        tmp = self._push_null_weights(self.null_weight(), **kwargs)
        return tmp.trim() if trim else tmp

    def null_weight(self):
        """
        Compute the map from nonterminal to total weight of generating the empty string.

        Returns:
            A chart mapping nonterminals to their null weights
        """
        # Restrict to rules whose bodies are all nonterminals: any terminal in
        # the body rules out generating the empty string through that rule.
        eps_grammar = self.spawn(V=set())
        for rule in self:
            if all(not self.is_terminal(y) for y in rule.body):
                eps_grammar.add(rule.w, rule.head, *rule.body)
        return eps_grammar.agenda()

    def null_weight_start(self):
        """
        Compute the null weight of the start symbol.

        Returns:
            The total weight of generating the empty string from the start symbol
        """
        chart = self.null_weight()
        return chart[self.S]

    def _push_null_weights(self, null_weight, rename=NotNull):
        """
        Returns a grammar that generates the same weighted language but is nullary-free
        at all nonterminals except its start symbol.

        Args:
            null_weight: Dict mapping nonterminals to their null weights
            rename: Function to rename nonterminals (default: NotNull)

        Returns:
            A new CFG without nullary rules (except at start)
        """
        # This method requires that `separate_start` has already been run:
        # the start symbol must never appear on a right-hand side.
        assert self.S not in {y for r in self for y in r.body}

        def relabel(x):
            "Rename a nonterminal unless it is null-free or the start symbol."
            keep_name = null_weight[x] == self.R.zero or x == self.S
            return x if keep_name else rename(x)

        out = self.spawn()
        # the start symbol retains its null weight via a single nullary rule
        out.add(null_weight[self.S], self.S)

        for r in self:
            if not r.body:
                continue  # drop nullary rule

            # For every subset of body positions, "null out" that subset by
            # folding the positions' null weights into the rule weight.
            for mask in product([0, 1], repeat=len(r.body)):
                w = r.w
                kept = []
                for y, nulled in zip(r.body, mask):
                    if nulled:
                        w = w * null_weight[y]
                    else:
                        kept.append(relabel(y))

                # exclude the cases that would be new nullary rules!
                if kept:
                    out.add(w, relabel(r.head), *kept)

        return out

    def separate_start(self):
        """
        Ensure that the start symbol does not appear on the RHS of any rule.

        Returns:
            A new CFG with start symbol only on LHS
        """
        rhs_symbols = {y for r in self for y in r.body}
        if self.S not in rhs_symbols:
            return self
        # Introduce a fresh start symbol that rewrites to the old one with
        # weight one, and copy all existing rules unchanged.
        fresh = _gen_nt(self.S)
        out = self.spawn(S=fresh)
        out.add(self.R.one, fresh, self.S)
        for r in self:
            out.add(r.w, r.head, *r.body)
        return out

    def separate_terminals(self):
        """
        Ensure that each terminal is produced by a preterminal rule.

        Returns:
            A new CFG with terminals only in preterminal rules
        """
        out = self.spawn()
        cache = {}

        def preterminal_for(a):
            "Fresh nonterminal rewriting to terminal `a` (memoized per terminal)."
            nt = cache.get(a)
            if nt is None:
                nt = out.add(self.R.one, _gen_nt(), a)
                cache[a] = nt
            return nt

        for r in self:
            if len(r.body) == 1 and self.is_terminal(r.body[0]):
                # already a preterminal rule; keep as-is
                out.add(r.w, r.head, *r.body)
                continue
            # replace each embedded terminal by its preterminal nonterminal
            body = [
                preterminal_for(y).head if self.is_terminal(y) else y
                for y in r.body
            ]
            out.add(r.w, r.head, *body)

        return out

    def binarize(self):
        """
        Return an equivalent grammar with arity ≤ 2.

        Returns:
            A new CFG with binary rules
        """
        out = self.spawn()
        agenda = list(self)
        while agenda:
            r = agenda.pop()
            if len(r.body) > 2:
                # Fold the first two subgoals into a fresh nonterminal and
                # requeue the resulting (strictly shorter) rules.
                agenda.extend(self._fold(r, [(0, 1)]))
            else:
                out.add(r.w, r.head, *r.body)
        return out

    def _fold(self, p, I):
        """
        Helper method for binarization - folds a rule into binary rules.

        Args:
            p: The rule to fold
            I: List of (lo, hi) index pairs; each inclusive span of the body
               is replaced by a fresh nonterminal

        Returns:
            List of new rules
        """
        # one fresh nonterminal (and one rule) per folded span
        out = []
        fresh = []
        for lo, hi in I:
            nt = _gen_nt()
            fresh.append(nt)
            out.append(Rule(self.R.one, nt, p.body[lo : hi + 1]))

        # rebuild the original body with each span replaced by its fresh symbol
        new_body = ()
        pos = 0
        for (lo, hi), nt in zip(I, fresh):
            new_body += p.body[pos:lo] + (nt,)
            pos = hi + 1
        new_body += p.body[pos:]
        out.append(Rule(p.w, p.head, new_body))

        return out

    @cached_property
    def cnf(self):
        """
        Transform this grammar into Chomsky Normal Form (CNF).

        Returns:
            A new CFG in CNF
        """
        # separate terminals, eliminate nullary rules, then unary rules,
        # trimming after each destructive step
        g = self.separate_terminals()
        g = g.nullaryremove(binarize=True).trim()
        g = g.unaryremove().trim()
        assert g.in_cnf(), "\n".join(
            str(r) for r in g._find_invalid_cnf_rule()
        )  # pragma: no cover
        return g

    # TODO: make CNF grammars a specialized subclass of CFG.
    @cached_property
    def _cnf(self):
        """
        Index the rules of a CNF grammar by arity.

        Note: Throws an exception if the grammar is not in CNF.

        Returns:
            Tuple of (nullary weight, terminal rules dict, binary rules list)
        """
        nullary = self.R.zero
        terminal = defaultdict(list)
        binary = []
        for r in self:
            arity = len(r.body)
            if arity == 0:
                nullary += r.w
                assert r.head == self.S, [self.S, r]
            elif arity == 1:
                terminal[r.body[0]].append(r)
                assert self.is_terminal(r.body[0])
            else:
                assert arity == 2
                binary.append(r)
                assert self.is_nonterminal(r.body[0])
                assert self.is_nonterminal(r.body[1])
        return (nullary, terminal, binary)

    def in_cnf(self):
        """
        Return True if the grammar is in Chomsky Normal Form.

        Returns:
            True if no rule violates the CNF constraints
        """
        # Short-circuit on the first violating rule rather than materializing
        # the complete list of violations.
        return next(self._find_invalid_cnf_rule(), None) is None

    def _find_invalid_cnf_rule(self):
        """
        Yield the rules that violate Chomsky Normal Form.

        A rule is CNF-valid if it is (a) a nullary rule at the start symbol,
        (b) a preterminal rule X -> a with `a` a terminal, or (c) a binary
        rule X -> Y Z where Y and Z are nonterminals other than the start
        symbol.

        Yields:
            Rules that violate CNF
        """
        for r in self:
            assert r.head in self.N
            if len(r.body) == 0 and r.head == self.S:
                continue
            elif len(r.body) == 1 and self.is_terminal(r.body[0]):
                continue
            elif len(r.body) == 2 and all(
                self.is_nonterminal(y) and y != self.S for y in r.body
            ):
                continue
            else:
                yield r

    #    def has_nullary(self):
    #        return any((len(p.body) == 0) for p in self if p.head != self.S)

    def unfold(self, i, k):
        """
        Apply the unfolding transformation to rule i and subgoal k.

        Args:
            i: Index of rule to unfold
            k: Index of subgoal in rule body

        Returns:
            A new CFG with the rule unfolded
        """
        assert isinstance(i, int) and isinstance(k, int)
        target = self.rules[i]
        assert self.is_nonterminal(target.body[k])

        out = self.spawn()
        # copy every rule except the one being unfolded
        for idx, r in enumerate(self):
            if idx == i:
                continue
            out.add(r.w, r.head, *r.body)

        # splice each expansion of the k-th subgoal into the target rule
        prefix = target.body[:k]
        suffix = target.body[k + 1 :]
        for r in self.rhs[target.body[k]]:
            out.add(target.w * r.w, target.head, *prefix, *r.body, *suffix)

        return out

    def dependency_graph(self):
        """
        Head-to-body dependency graph of the rules of the grammar.

        Returns:
            A WeightedGraph representing dependencies between symbols
        """
        graph = WeightedGraph(Boolean)
        for r in self:
            head = r.head
            for y in r.body:
                graph[head, y] += Boolean.one
        # ensure isolated symbols still appear as nodes
        graph.N |= self.N
        graph.N |= self.V
        return graph

    # TODO: the default treesum algorithm should probably be SCC-decomposed newton's method
    # def agenda(self, tol=1e-12, maxiter=float('inf')):
    def agenda(self, tol=1e-12, maxiter=100_000):
        """
        Agenda-based semi-naive evaluation for treesums.

        Symbols are grouped into buckets by the dependency graph; buckets are
        drained from the highest index down to 0, propagating weight *changes*
        (deltas) instead of recomputing rules from scratch.

        Args:
            tol: Convergence tolerance
            maxiter: Maximum iterations per bucket before moving on

        Returns:
            A chart containing the treesum weights
        """
        old = self.R.chart()

        # precompute the mapping from updates to where they need to go
        routing = defaultdict(list)
        for r in self:
            for k in range(len(r.body)):
                routing[r.body[k]].append((r, k))

        deps = self.dependency_graph()
        blocks = deps.blocks
        bucket = deps.buckets

        # helper function: enqueue a weight change for symbol x in its bucket
        def update(x, W):
            change[bucket[x]][x] += W

        change = defaultdict(self.R.chart)
        # base case: each terminal generates itself with weight one
        for a in self.V:
            update(a, self.R.one)

        # base case: nullary rules contribute their weight to the head
        for r in self:
            if len(r.body) == 0:
                update(r.head, r.w)

        b = len(blocks)
        iteration = 0
        while b >= 0:
            iteration += 1

            # Move on to the next block
            if len(change[b]) == 0 or iteration > maxiter:
                b -= 1
                iteration = 0  # reset iteration number for the next bucket
                continue

            u, v = change[b].popitem()

            new = old[u] + v

            # skip negligible changes
            if self.R.metric(old[u], new) <= tol:
                continue

            # Propagate the delta to every rule position where `u` occurs.
            # When `u` appears several times in one body, double counting is
            # avoided by using the updated value at positions before k, the
            # delta `v` at position k, and the stale value after k.
            for r, k in routing[u]:
                W = r.w
                for j in range(len(r.body)):
                    if u == r.body[j]:
                        if j < k:
                            W *= new
                        elif j == k:
                            W *= v
                        else:
                            W *= old[u]
                    else:
                        W *= old[r.body[j]]

                update(r.head, W)

            old[u] = new

        return old

    def naive_bottom_up(self, *, tol=1e-12, timeout=100_000):
        """
        Naive bottom-up evaluation for treesums; better to use `agenda`.

        Args:
            tol: Convergence tolerance per nonterminal
            timeout: Maximum number of fixpoint iterations

        Returns:
            A chart mapping symbols to (approximate) treesum weights
        """
        chart = self.R.chart()
        for _ in range(timeout):
            stepped = self._bottom_up_step(chart)
            # stop once every nonterminal's value has converged
            if all(
                self.R.metric(stepped[X], chart[X]) <= tol for X in self.N
            ):
                break
            chart = stepped
        return chart

    def _bottom_up_step(self, V):
        R = self.R
        one = R.one
        U = R.chart()
        for a in self.V:
            U[a] = one
        for p in self:
            update = p.w
            for X in p.body:
                if self.is_nonterminal(X):
                    update *= V[X]
            U[p.head] += update
        return U

    def prefix_weight(self, xs):
        """
        Total weight of all derivations that have `xs` as a prefix.

        Args:
            xs: A sequence of terminal symbols

        Returns:
            The prefix weight, obtained by evaluating the cached
            `prefix_grammar` on `xs`.
        """
        return self.prefix_grammar(xs)

    @cached_property
    def prefix_grammar(self):
        """
        The prefix grammar generates the prefix language of the parent grammar.
        PG[x] = sum_[s in Σ^*] G[xs].

        Construction: each nonterminal X gets an "arrow" variant that derives
        weighted prefixes of X's language; the treesum chart `W` supplies the
        total weight of whatever suffix material is cut off.
        """
        pg = self.spawn()
        W = self.agenda()

        # Generate a unique ID for this prefix transformation to avoid collisions
        arrow_id = _arrow_nt.next_id()

        pg.S = _gen_nt(self.S)
        pg.add(
            self.R.one, pg.S, _arrow_nt(self.S, arrow_id)
        )  # Attach start symbol to arrow nonterminal
        pg.add(
            W[self.S],
            pg.S,
        )  # The prefix weight of the empty string

        for r in self:
            # original rule, copied unchanged
            pg.add(r.w, r.head, *r.body)
            # `w` accumulates the treesum of the dropped suffix r.body[i+1:]
            w = self.R.one
            for i in range(
                len(r.body) - 1, -1, -1
            ):  # Add the prefixed rules, with the "arrow" non-terminals on the right.
                if self.is_terminal(r.body[i]):  # For a terminal, a^ := a
                    pg.add(r.w * w, _arrow_nt(r.head, arrow_id), *r.body[:i], r.body[i])
                else:  # Otherwise, the arrow is transmitted downwards on the right spine of the derivation.
                    pg.add(
                        r.w * w,
                        _arrow_nt(r.head, arrow_id),
                        *r.body[:i],
                        _arrow_nt(r.body[i], arrow_id),
                    )
                w = w * W[r.body[i]]
        return pg

    def derivatives(self, s):
        """
        Return the sequence of derivatives for each prefix of `s`.

        The result has len(s) + 1 entries; entry 0 is this grammar itself and
        entry m is the derivative with respect to the prefix s[:m].
        """
        out = [self]
        for x in s:
            out.append(out[-1].derivative(x))
        return out

    # Implementation note: This implementation of the derivative grammar
    # performs nullary elimination at the same time.
    def derivative(self, a, i=0):
        """
        Return a grammar that generates the derivative with respect to `a`.

        Args:
            a: The terminal symbol to differentiate with respect to
            i: Tag forwarded to the `Slash` nonterminal constructor

        Returns:
            A new CFG whose start symbol is `Slash(self.S, a, i=i)`
        """

        def slash(x, y):
            return Slash(x, y, i=i)

        D = self.spawn(S=slash(self.S, a))
        U = self.null_weight()
        for r in self:
            D.add(r.w, r.head, *r.body)
            # This condition is loop-invariant (it does not depend on the body
            # position), so check it once per rule instead of per position.
            # NOTE(review): presumably it skips symbols that were already
            # slashed by an earlier derivative — confirm.
            if slash(r.head, a) in self.N:
                continue
            # `delta` accumulates the null weight of the skipped prefix
            # r.body[:k], fusing nullary elimination into the derivative.
            delta = self.R.one
            for k, y in enumerate(r.body):
                if self.is_terminal(y):
                    if y == a:
                        D.add(delta * r.w, slash(r.head, a), *r.body[k + 1 :])
                else:
                    D.add(
                        delta * r.w,
                        slash(r.head, a),
                        slash(r.body[k], a),
                        *r.body[k + 1 :],
                    )
                delta *= U[y]
        return D

    def _compose_bottom_up_epsilon(self, fst):
        """
        Determine which items of the composition grammar are supported.

        Bottom-up reachability over items (i, X, pending, j): symbol X spans
        FST states i..j with the tuple `pending` of subgoals still to match.

        Args:
            fst: The transducer being composed with this grammar

        Returns:
            Dict mapping (state, symbol) to the set of reachable end states
        """

        A = set()  # the agenda of items still to process

        I = defaultdict(set)  # incomplete items
        C = defaultdict(set)  # complete items
        R = defaultdict(set)  # rules indexed by first subgoal; non-nullary

        # NOTE(review): these extra rules appear to interleave EPSILON moves
        # of the FST with terminals and after the start symbol — confirm
        # against the composition construction.
        special_rules = [Rule(self.R.one, a, (EPSILON, a)) for a in self.V] + [
            Rule(self.R.one, Other(self.S), (self.S,)),
            Rule(self.R.one, Other(self.S), (Other(self.S), EPSILON)),
        ]

        for r in itertools.chain(self, special_rules):
            if len(r.body) > 0:
                R[r.body[0]].add(r)

        # we have two base cases:
        #
        # base case 1: arcs
        for i, (a, _), j, _ in fst.arcs():
            A.add((i, a, (), j))  # empty tuple -> the rule 'complete'

        # base case 2: nullary rules (empty span at every state)
        for r in self:
            if len(r.body) == 0:
                for i in fst.states:
                    A.add((i, r.head, (), i))

        # drain the agenda
        while A:
            (i, X, Ys, j) = A.pop()

            # No pending items ==> the item is complete
            if not Ys:
                if j in C[i, X]:
                    continue  # already known; avoid re-propagation
                C[i, X].add(j)

                # combine the newly completed item with incomplete rules that are
                # looking for an item like this one
                for h, X1, Zs in I[i, X]:
                    A.add((h, X1, Zs[1:], j))

                # initialize rules that can start with an item like this one
                for r in R[X]:
                    A.add((i, r.head, r.body[1:], j))

            # Still have pending items ==> advanced the pending items
            else:
                if (i, X, Ys) in I[j, Ys[0]]:
                    continue  # already recorded
                I[j, Ys[0]].add((i, X, Ys))

                # advance against everything already completed for Ys[0] at j
                for k in C[j, Ys[0]]:
                    A.add((i, X, Ys[1:], k))

        return C

    def __matmul__(self, fst):
        """
        Return a CFG denoting the pointwise product or composition of `self`
        and `fst`.

        Args:
            fst: An FST, an FSA-like object providing `to_fst`, or a
                str/tuple (coerced into a diagonal FST)

        Returns:
            The composed CFG over the FST's output alphabet
        """

        # coerce something sequence like into a diagonal FST
        if isinstance(fst, (str, tuple)):
            fst = FST.from_string(fst, self.R)
        # coerce something FSA-like into an FST, might throw an error
        if not isinstance(fst, FST):
            fst = fst.to_fst()

        # Initialize the new CFG:
        # - its start symbol is chosen arbitrarily to be `self.S`
        # - its alphabet changes - it is now the 'output' alphabet of the transducer
        new_start = self.S
        new = self.spawn(S=new_start, V=fst.B - {EPSILON})

        # The bottom-up intersection algorithm is a two-pass algorithm
        #
        # Pass 1: Determine the set of items that are possibly nonzero-valued
        C = self._compose_bottom_up_epsilon(fst)

        # same epsilon-handling rules as used in the support pass
        special_rules = [Rule(self.R.one, a, (EPSILON, a)) for a in self.V] + [
            Rule(self.R.one, Other(self.S), (self.S,)),
            Rule(self.R.one, Other(self.S), (Other(self.S), EPSILON)),
        ]

        def join(start, Ys):
            """
            Helper method; expands the rule body

            Given Ys = [Y_1, ... Y_K], we will enumerate expansions of the form

            (s_0, Y_1, s_1), (s_1, Y_2, s_2), ..., (s_{K-1}, Y_K, s_K)

            where each (s_{k-1}, Y_k, s_k) in the expansion is a completed item
            (i.e., \forall k: (s_{k-1}, Y_k, s_k) in C).
            """
            if not Ys:
                yield []
            else:
                for K in C[start, Ys[0]]:
                    for rest in join(K, Ys[1:]):
                        yield [(start, Ys[0], K)] + rest

        start = {I for (I, _) in C}

        # Pass 2: instantiate the rules over the supported items
        for r in itertools.chain(self, special_rules):
            if len(r.body) == 0:
                # nullary rules hold at every state with an empty span
                for s in fst.states:
                    new.add(r.w, (s, r.head, s))
            else:
                for I in start:
                    for rhs in join(I, r.body):
                        K = rhs[-1][-1]
                        new.add(r.w, (I, r.head, K), *rhs)

        # connect the FST's initial/final state pairs to the new start symbol
        for i, wi in fst.start.items():
            for k, wf in fst.stop.items():
                new.add(wi * wf, new_start, (i, Other(self.S), k))

        # terminal productions derived from the FST's arcs
        for i, (a, b), j, w in fst.arcs():
            if b == EPSILON:
                new.add(w, (i, a, j))
            else:
                new.add(w, (i, a, j), b)
        return new

    def truncate_length(self, max_length):
        """
        Transform this grammar so that it only generates strings with length ≤ `max_length`.

        Args:
            max_length: Maximum string length to allow

        Returns:
            The composition of this grammar with a length-bounding automaton
        """
        from genlm.grammar import WFSA

        # Build a chain-shaped automaton accepting every string of length
        # 0..max_length, then intersect the grammar with it.
        one = self.R.one
        fsa = WFSA(self.R)
        fsa.add_I(0, one)
        fsa.add_F(0, one)
        for t in range(max_length):
            for a in self.V:
                fsa.add_arc(t, a, t + 1, one)
            fsa.add_F(t + 1, one)
        return self @ fsa

    def materialize(self, max_length):
        """
        Return a `Chart` with this grammar's weighted language for strings ≤ `max_length`.

        Works on the CNF-converted grammar; the final filter re-enforces the
        length bound on the enumerated strings.
        """
        return self.cnf.language(max_length).filter(lambda x: len(x) <= max_length)

    def to_bytes(self):
        """Convert terminal symbols from strings to bytes representation.

        This method creates a new grammar where all terminal string symbols are
        converted to their UTF-8 byte representation (each terminal becomes a
        sequence of ints in 0..255). Non-terminal symbols are preserved as-is.

        Returns:
            CFG: A new grammar with byte terminal symbols

        Raises:
            ValueError: If a terminal symbol is not a string
        """
        new = self.spawn(S=self.S, R=self.R, V=set())

        for r in self:
            new_body = []
            for x in r.body:
                if self.is_terminal(x):
                    if not isinstance(x, str):
                        raise ValueError(f"unsupported terminal type: {type(x)}")
                    # encode() yields ints when iterated; no intermediate list needed
                    bs = x.encode("utf-8")
                    new.V.update(bs)
                    new_body.extend(bs)
                else:
                    new_body.append(x)
            new.add(r.w, r.head, *new_body)

        return new

cnf cached property

Transform this grammar into Chomsky Normal Form (CNF).

Returns:

Type Description

A new CFG in CNF

expected_length property

Compute the expected length of a string using the Expectation semiring.

Returns:

Type Description

The expected length of strings generated by this grammar

Raises:

Type Description
AssertionError

If grammar is not over the Float semiring

num_rules property

Return number of rules in the grammar.

prefix_grammar cached property

The prefix grammar generates the prefix language of the parent grammar. PG[x] = sum_[s in Σ^*] G[xs].

rhs cached property

Map from each nonterminal to the list of rules with it as their left-hand side.

Returns:

Type Description

A dict mapping nonterminals to lists of rules

size property

Return total size of the grammar (sum of rule lengths).

__call__(xs)

Compute the total weight of the sequence xs.

Parameters:

Name Type Description Default
xs

A sequence of terminal symbols

required

Returns:

Type Description

The total weight of all derivations of xs

Source code in genlm/grammar/cfg.py
def __call__(self, xs):
    """
    Compute the total weight of the sequence xs.

    Args:
        xs: A sequence of terminal symbols

    Returns:
        The total weight of all derivations of xs
    """
    self = self.cnf  # need to do this here because the start symbol might change
    return self._parse_chart(xs)[0, self.S, len(xs)]

__getitem__(root)

Return a grammar that denotes the sublanguage of the nonterminal root.

Parameters:

Name Type Description Default
root

The nonterminal to use as the new start symbol

required

Returns:

Type Description

A new CFG with root as the start symbol

Source code in genlm/grammar/cfg.py
def __getitem__(self, root):
    """
    Return a grammar that denotes the sublanguage of the nonterminal `root`.

    Args:
        root: The nonterminal to use as the new start symbol

    Returns:
        A new CFG with root as the start symbol
    """
    new = self.spawn(S=root)
    for r in self:
        new.add(r.w, r.head, *r.body)
    return new

__init__(R, S, V)

Initialize a weighted CFG.

Parameters:

Name Type Description Default
R

The semiring for rule weights

required
S

The start symbol (nonterminal)

required
V

The set of terminal symbols (vocabulary)

required
Source code in genlm/grammar/cfg.py
def __init__(self, R, S, V):
    """
    Initialize a weighted CFG.

    Args:
        R: The semiring for rule weights
        S: The start symbol (nonterminal)
        V: The set of terminal symbols (vocabulary)
    """
    self.R = R  # semiring
    self.V = V  # alphabet
    self.N = {S}  # nonterminals
    self.S = S  # unique start symbol
    self.rules = []  # rules
    self._trim_cache = [None, None]

__iter__()

Iterate over the rules in the grammar.

Source code in genlm/grammar/cfg.py
def __iter__(self):
    """Iterate over the rules in the grammar."""
    return iter(self.rules)

__len__()

Return number of rules in the grammar.

Source code in genlm/grammar/cfg.py
def __len__(self):
    """Return number of rules in the grammar."""
    return len(self.rules)

__matmul__(fst)

Return a CFG denoting the pointwise product or composition of self and fst.

Source code in genlm/grammar/cfg.py
def __matmul__(self, fst):
    "Return a CFG denoting the pointwise product or composition of `self` and `fs`."

    # coerce something sequence like into a diagonal FST
    if isinstance(fst, (str, tuple)):
        fst = FST.from_string(fst, self.R)
    # coerce something FSA-like into an FST, might throw an error
    if not isinstance(fst, FST):
        fst = fst.to_fst()

    # Initialize the new CFG:
    # - its start symbol is chosen arbitrarily to be `self.S`
    # - its the alphabet changes - it is now 'output' alphabet of the transducer
    new_start = self.S
    new = self.spawn(S=new_start, V=fst.B - {EPSILON})

    # The bottom-up intersection algorithm is a two-pass algorithm
    #
    # Pass 1: Determine the set of items that are possibly nonzero-valued
    C = self._compose_bottom_up_epsilon(fst)

    special_rules = [Rule(self.R.one, a, (EPSILON, a)) for a in self.V] + [
        Rule(self.R.one, Other(self.S), (self.S,)),
        Rule(self.R.one, Other(self.S), (Other(self.S), EPSILON)),
    ]

    def join(start, Ys):
        """
        Helper method; expands the rule body

        Given Ys = [Y_1, ... Y_K], we will enumerate expansion of the form

        (s_0, Y_1, s_1), (s_1, Y_2, s_2), ..., (s_{K-1}, Y_K, s_K)

        where each (s_{k-1}, Y_k, s_k) in the expansion is a completed item
        (i.e., \forall k: (s_{k-1}, Y_k, s_k) in C).
        """
        if not Ys:
            yield []
        else:
            for K in C[start, Ys[0]]:
                for rest in join(K, Ys[1:]):
                    yield [(start, Ys[0], K)] + rest

    start = {I for (I, _) in C}

    for r in itertools.chain(self, special_rules):
        if len(r.body) == 0:
            for s in fst.states:
                new.add(r.w, (s, r.head, s))
        else:
            for I in start:
                for rhs in join(I, r.body):
                    K = rhs[-1][-1]
                    new.add(r.w, (I, r.head, K), *rhs)

    for i, wi in fst.start.items():
        for k, wf in fst.stop.items():
            new.add(wi * wf, new_start, (i, Other(self.S), k))

    for i, (a, b), j, w in fst.arcs():
        if b == EPSILON:
            new.add(w, (i, a, j))
        else:
            new.add(w, (i, a, j), b)
    return new

__repr__()

Return string representation of the grammar.

Source code in genlm/grammar/cfg.py
def __repr__(self):
    """Return string representation of the grammar."""
    return "Grammar {\n%s\n}" % "\n".join(f"  {r}" for r in self)

add(w, head, *body)

Add a rule of the form w: head -> body1, body2, ... body_k.

Parameters:

Name Type Description Default
w

The rule weight

required
head

The left-hand side nonterminal

required
*body

The right-hand side symbols

()

Returns:

Type Description

The added rule, or None if weight is zero

Source code in genlm/grammar/cfg.py
def add(self, w, head, *body):
    """
    Add a rule of the form w: head -> body1, body2, ... body_k.

    Args:
        w: The rule weight
        head: The left-hand side nonterminal
        *body: The right-hand side symbols

    Returns:
        The added rule, or None if weight is zero
    """
    if w == self.R.zero:
        return  # skip rules with weight zero
    self.N.add(head)
    r = Rule(w, head, body)
    self.rules.append(r)
    return r

agenda(tol=1e-12, maxiter=100000)

Agenda-based semi-naive evaluation for treesums.

Parameters:

Name Type Description Default
tol

Convergence tolerance

1e-12
maxiter

Maximum iterations

100000

Returns:

Type Description

A chart containing the treesum weights

Source code in genlm/grammar/cfg.py
def agenda(self, tol=1e-12, maxiter=100_000):
    """
    Agenda-based semi-naive evaluation for treesums.

    Args:
        tol: Convergence tolerance
        maxiter: Maximum iterations

    Returns:
        A chart containing the treesum weights
    """
    old = self.R.chart()

    # precompute the mapping from updates to where they need to go
    routing = defaultdict(list)
    for r in self:
        for k in range(len(r.body)):
            routing[r.body[k]].append((r, k))

    deps = self.dependency_graph()
    blocks = deps.blocks
    bucket = deps.buckets

    # helper function
    def update(x, W):
        change[bucket[x]][x] += W

    change = defaultdict(self.R.chart)
    for a in self.V:
        update(a, self.R.one)

    for r in self:
        if len(r.body) == 0:
            update(r.head, r.w)

    b = len(blocks)
    iteration = 0
    while b >= 0:
        iteration += 1

        # Move on to the next block
        if len(change[b]) == 0 or iteration > maxiter:
            b -= 1
            iteration = 0  # reset iteration number for the next bucket
            continue

        u, v = change[b].popitem()

        new = old[u] + v

        if self.R.metric(old[u], new) <= tol:
            continue

        for r, k in routing[u]:
            W = r.w
            for j in range(len(r.body)):
                if u == r.body[j]:
                    if j < k:
                        W *= new
                    elif j == k:
                        W *= v
                    else:
                        W *= old[u]
                else:
                    W *= old[r.body[j]]

            update(r.head, W)

        old[u] = new

    return old

assert_equal(other, verbose=False, throw=True)

Assertion for the equality of self and other modulo rule reordering.

Parameters:

Name Type Description Default
other

The grammar to compare against

required
verbose

If True, print differences

False
throw

If True, raise AssertionError on inequality

True

Raises:

Type Description
AssertionError

If grammars are not equal and throw=True

Source code in genlm/grammar/cfg.py
def assert_equal(self, other, verbose=False, throw=True):
    """
    Assertion for the equality of self and other modulo rule reordering.

    Args:
        other: The grammar to compare against
        verbose: If True, print differences
        throw: If True, raise AssertionError on inequality

    Raises:
        AssertionError: If grammars are not equal and throw=True
    """
    assert verbose or throw
    if isinstance(other, str):
        other = self.__class__.from_string(other, self.R)
    if verbose:
        # TODO: need to check the weights in the print out; we do it in the assertion
        S = set(self.rules)
        G = set(other.rules)
        for r in sorted(S | G, key=str):
            if r in S and r in G:
                continue
            # if r in S and r not in G: continue
            # if r not in S and r in G: continue
            print(
                colors.mark(r in S),
                # colors.mark(r in S and r in G),
                colors.mark(r in G),
                r,
            )
    assert not throw or Counter(self.rules) == Counter(other.rules), (
        f"\n\nhave=\n{str(self)}\nwant=\n{str(other)}"
    )

binarize()

Return an equivalent grammar with arity ≤ 2.

Returns:

Type Description

A new CFG with binary rules

Source code in genlm/grammar/cfg.py
def binarize(self):
    """
    Return an equivalent grammar with arity ≤ 2.

    Returns:
        A new CFG with binary rules
    """
    new = self.spawn()

    stack = list(self)
    while stack:
        p = stack.pop()
        if len(p.body) <= 2:
            new.add(p.w, p.head, *p.body)
        else:
            stack.extend(self._fold(p, [(0, 1)]))

    return new

cotrim()

Trim the grammar so that all nonterminals are generating.

Returns:

Type Description

A new CFG with only generating nonterminals

Source code in genlm/grammar/cfg.py
def cotrim(self):
    """
    Trim the grammar so that all nonterminals are generating.

    Returns:
        A new CFG with only generating nonterminals
    """
    return self.trim(bottomup_only=True)

dependency_graph()

Head-to-body dependency graph of the rules of the grammar.

Returns:

Type Description

A WeightedGraph representing dependencies between symbols

Source code in genlm/grammar/cfg.py
def dependency_graph(self):
    """
    Head-to-body dependency graph of the rules of the grammar.

    Returns:
        A WeightedGraph representing dependencies between symbols
    """
    deps = WeightedGraph(Boolean)
    for r in self:
        for y in r.body:
            deps[r.head, y] += Boolean.one
    deps.N |= self.N
    deps.N |= self.V
    return deps

derivations(X, H)

Enumerate derivations of symbol X with height <= H.

Parameters:

Name Type Description Default
X

The symbol to derive from (default: start symbol)

required
H

Maximum derivation height

required

Yields:

Type Description

Derivation objects representing derivation trees

Source code in genlm/grammar/cfg.py
def derivations(self, X, H):
    """Enumerate derivations of symbol X with height <= H.

    Args:
        X: The symbol to derive from (pass None for the start symbol)
        H: Maximum derivation height

    Yields:
        Derivation objects for nonterminals; the symbol itself for terminals
    """
    X = self.S if X is None else X
    if self.is_terminal(X):
        # A terminal derives itself regardless of the remaining height budget.
        yield X
        return
    if H <= 0:
        return
    for rule in self.rhs[X]:
        for children in self._derivations_list(rule.body, H - 1):
            yield Derivation(rule, X, *children)

derivative(a, i=0)

Return a grammar that generates the derivative with respect to a.

Source code in genlm/grammar/cfg.py
def derivative(self, a, i=0):
    """Return a grammar that generates the derivative with respect to `a`,
    i.e., the residual language after consuming the terminal `a`.

    Args:
        a: The terminal to differentiate with respect to
        i: Tag passed through to the `Slash` nonterminal constructor

    Returns:
        A new CFG whose start symbol is ``Slash(self.S, a)``
    """

    def slash(x, y):
        # Fresh "x / y" nonterminal; `i` distinguishes repeated derivatives.
        return Slash(x, y, i=i)

    D = self.spawn(S=slash(self.S, a))
    U = self.null_weight()  # nullability weights, used to skip over ε-prefixes
    for r in self:
        D.add(r.w, r.head, *r.body)
        # `delta` accumulates the null weight of r.body[:k] as k advances.
        delta = self.R.one
        for k, y in enumerate(r.body):
            # NOTE(review): this guard does not depend on k, so when it fires
            # it skips the whole rule (including the delta updates below) —
            # presumably to avoid re-deriving already-slashed rules; confirm.
            if slash(r.head, a) in self.N:
                continue  # SKIP!
            if self.is_terminal(y):
                if y == a:
                    D.add(delta * r.w, slash(r.head, a), *r.body[k + 1 :])
            else:
                D.add(
                    delta * r.w,
                    slash(r.head, a),
                    slash(r.body[k], a),
                    *r.body[k + 1 :],
                )
            delta *= U[y]
    return D

derivatives(s)

Return the sequence of derivatives for each prefix of s.

Source code in genlm/grammar/cfg.py
def derivatives(self, s):
    """Return the sequence of derivatives for each prefix of `s`.

    The result has ``len(s) + 1`` entries: entry 0 is this grammar, and
    entry m+1 is the derivative of entry m with respect to ``s[m]``.
    """
    out = [self]
    for sym in s:
        out.append(out[-1].derivative(sym))
    return out

from_string(string, semiring, comment='#', start='S', is_terminal=lambda x: not x[0].isupper()) classmethod

Create a CFG from a string representation.

Parameters:

Name Type Description Default
string

The grammar rules as a string

required
semiring

The semiring for rule weights

required
comment

Comment character to ignore lines (default: '#')

'#'
start

Start symbol (default: 'S')

'S'
is_terminal

Function to identify terminal symbols (default: lowercase first letter)

lambda x: not x[0].isupper()

Returns:

Type Description

A new CFG instance

Source code in genlm/grammar/cfg.py
@classmethod
def from_string(
    cls,
    string,
    semiring,
    comment="#",
    start="S",
    is_terminal=lambda x: not x[0].isupper(),
):
    """
    Create a CFG from a string representation.

    Each non-blank, non-comment line has the form ``weight: LHS → RHS...``
    (``->`` is accepted as a synonym for the arrow).

    Args:
        string: The grammar rules as a string
        semiring: The semiring for rule weights
        comment: Comment character to ignore lines (default: '#')
        start: Start symbol (default: 'S')
        is_terminal: Function to identify terminal symbols (default: lowercase first letter)

    Returns:
        A new CFG instance

    Raises:
        ValueError: If a line does not match the expected rule format
    """
    # `V` is the same set object held by `cfg`, so terminals discovered
    # below become part of the grammar's vocabulary.
    V = set()
    cfg = cls(R=semiring, S=start, V=V)
    string = string.replace("->", "→")  # synonym for the arrow
    for line in string.split("\n"):
        line = line.strip()
        if not line or line.startswith(comment):
            continue
        try:
            # Exactly one "weight: lhs → rhs" match per line; zero or many
            # matches raise ValueError through the unpacking.
            [(w, lhs, rhs)] = re.findall(r"(.*):\s*(\S+)\s*→\s*(.*)$", line)
            lhs = lhs.strip()
            rhs = rhs.strip().split()
            for x in rhs:
                if is_terminal(x):
                    V.add(x)
            cfg.add(semiring.from_string(w), lhs, *rhs)
        except ValueError:
            raise ValueError(f"bad input line:\n{line}")  # pylint: disable=W0707
    return cfg

has_unary_cycle()

Check if the grammar has unary cycles.

Returns:

Type Description

True if the grammar contains unary cycles

Source code in genlm/grammar/cfg.py
def has_unary_cycle(self):
    """Check whether the grammar contains a unary cycle.

    A unary rule X → Y lies on a cycle exactly when X and Y fall into the
    same strongly connected component of the unary-rule graph.

    Returns:
        True if the grammar contains unary cycles
    """
    scc_of = self._unary_graph().buckets
    return any(
        len(r.body) == 1 and scc_of.get(r.head) == scc_of.get(r.body[0])
        for r in self
    )

in_cnf()

Return true if the grammar is in CNF.

Returns:

Type Description

True if grammar is in Chomsky Normal Form

Source code in genlm/grammar/cfg.py
def in_cnf(self):
    """Return True if the grammar is in Chomsky Normal Form.

    Returns:
        True iff `_find_invalid_cnf_rule` yields nothing
    """
    for _ in self._find_invalid_cnf_rule():
        return False
    return True

is_nonterminal(X)

Return True if X is a nonterminal symbol.

Source code in genlm/grammar/cfg.py
def is_nonterminal(self, X):
    """Return True if X is a nonterminal symbol (i.e., not a terminal)."""
    return not self.is_terminal(X)

is_terminal(x)

Return True if x is a terminal symbol.

Source code in genlm/grammar/cfg.py
def is_terminal(self, x):
    """Return True if x is a terminal symbol (a member of the vocabulary V)."""
    return x in self.V

language(depth)

Enumerate strings generated by this cfg by derivations up to the given depth.

Parameters:

Name Type Description Default
depth

Maximum derivation depth to consider

required

Returns:

Type Description

A chart containing the weighted language up to the given depth

Source code in genlm/grammar/cfg.py
def language(self, depth):
    """Enumerate the weighted language of this grammar up to a derivation depth.

    Args:
        depth: Maximum derivation depth to consider

    Returns:
        A chart mapping each derivable string to its accumulated weight
    """
    chart = self.R.chart()
    for deriv in self.derivations(self.S, depth):
        chart[deriv.Yield()] += deriv.weight()
    return chart

map_values(f, R)

Return a new grammar that is the result of applying f: self.R -> R to each rule's weight.

Parameters:

Name Type Description Default
f

Function to map weights

required
R

New semiring for weights

required

Returns:

Type Description

A new CFG with mapped weights

Source code in genlm/grammar/cfg.py
def map_values(self, f, R):
    """Apply f: self.R -> R to every rule weight.

    Args:
        f: Function mapping old weights to new weights
        R: The semiring of the new weights

    Returns:
        A new CFG over semiring R with mapped weights
    """
    out = self.spawn(R=R)
    for rule in self:
        out.add(f(rule.w), rule.head, *rule.body)
    return out

materialize(max_length)

Return a Chart with this grammar's weighted language for strings ≤ max_length.

Source code in genlm/grammar/cfg.py
def materialize(self, max_length):
    """Return a `Chart` with this grammar's weighted language for strings ≤ `max_length`."""
    lang = self.cnf.language(max_length)
    return lang.filter(lambda s: len(s) <= max_length)

naive_bottom_up(*, tol=1e-12, timeout=100000)

Naive bottom-up evaluation for treesums; better to use agenda.

Source code in genlm/grammar/cfg.py
def naive_bottom_up(self, *, tol=1e-12, timeout=100_000):
    """Naive bottom-up fixed-point evaluation of treesums; prefer `agenda`.

    Iterates `_bottom_up_step` until two successive charts agree within
    `tol` on every nonterminal, or `timeout` iterations have elapsed.
    """
    R = self.R
    prev = R.chart()
    for _ in range(timeout):
        cur = self._bottom_up_step(prev)
        # Converged when the step leaves every nonterminal (nearly) unchanged.
        if all(R.metric(cur[X], prev[X]) <= tol for X in self.N):
            break
        prev = cur
    return prev

null_weight()

Compute the map from nonterminal to total weight of generating the empty string.

Returns:

Type Description

A dict mapping nonterminals to their null weights

Source code in genlm/grammar/cfg.py
def null_weight(self):
    """Map each nonterminal to the total weight of deriving the empty string.

    Builds the sub-grammar of rules whose bodies contain no terminals
    (only such rules can participate in deriving ε) and solves it with `agenda`.

    Returns:
        A chart mapping nonterminals to their null weights
    """
    eps_only = self.spawn(V=set())
    for rule in self:
        if all(not self.is_terminal(sym) for sym in rule.body):
            eps_only.add(rule.w, rule.head, *rule.body)
    return eps_only.agenda()

null_weight_start()

Compute the null weight of the start symbol.

Returns:

Type Description

The total weight of generating the empty string from the start symbol

Source code in genlm/grammar/cfg.py
def null_weight_start(self):
    """Total weight of generating the empty string from the start symbol.

    Returns:
        The null weight of `self.S`
    """
    return self.null_weight()[self.S]

nullaryremove(binarize=True, trim=True, **kwargs)

Return an equivalent grammar with no nullary rules except for one at the start symbol.

Parameters:

Name Type Description Default
binarize

If True, binarize the grammar first

True
trim

If True, trim the resulting grammar

True
**kwargs

Additional arguments passed to _push_null_weights

{}

Returns:

Type Description

A new CFG without nullary rules (except at start)

Source code in genlm/grammar/cfg.py
def nullaryremove(self, binarize=True, trim=True, **kwargs):
    """Return an equivalent grammar whose only nullary rule (if any) is at the start.

    Args:
        binarize: If True, binarize the grammar first
        trim: If True, trim the resulting grammar
        **kwargs: Forwarded to `_push_null_weights`

    Returns:
        A new CFG without nullary rules (except possibly at the start symbol)
    """
    # Pushing null weights enumerates subsets of each rule body, which is
    # exponential in the rule's arity — binarizing first keeps that cheap.
    g = self.binarize() if binarize else self  # pragma: no cover
    g = g.separate_start()
    pushed = g._push_null_weights(g.null_weight(), **kwargs)
    return pushed.trim() if trim else pushed

prefix_weight(xs)

Total weight of all derivations that have xs as a prefix.

Source code in genlm/grammar/cfg.py
def prefix_weight(self, xs):
    """Total weight of all derivations that have `xs` as a prefix.

    Delegates to `prefix_grammar` applied to `xs`.
    """
    return self.prefix_grammar(xs)

rename(f)

Return a new grammar that is the result of applying f to each nonterminal.

Parameters:

Name Type Description Default
f

Function to rename nonterminals

required

Returns:

Type Description

A new CFG with renamed nonterminals

Source code in genlm/grammar/cfg.py
def rename(self, f):
    """Return a new grammar with every nonterminal renamed by f.

    Terminals are left untouched; the start symbol is renamed as well.

    Args:
        f: Function mapping old nonterminal names to new ones

    Returns:
        A new CFG with renamed nonterminals
    """
    out = self.spawn(S=f(self.S))
    for rule in self:
        body = [sym if self.is_terminal(sym) else f(sym) for sym in rule.body]
        out.add(rule.w, f(rule.head), *body)
    return out

renumber()

Rename nonterminals to integers.

Returns:

Type Description

A new CFG with integer nonterminals

Source code in genlm/grammar/cfg.py
def renumber(self):
    """Rename nonterminals to integers.

    The integers are offset above any integer already used as a terminal,
    so the new nonterminal names cannot collide with the vocabulary.

    Returns:
        A new CFG with integer nonterminals
    """
    index = Integerizer()
    offset = max((x for x in self.V if isinstance(x, int)), default=0) + 1
    return self.rename(lambda x: index(x) + offset)

separate_start()

Ensure that the start symbol does not appear on the RHS of any rule.

Returns:

Type Description

A new CFG with start symbol only on LHS

Source code in genlm/grammar/cfg.py
def separate_start(self):
    """Ensure the start symbol never appears on the right-hand side of a rule.

    Returns:
        This grammar unchanged if it is already separated; otherwise a new
        CFG whose fresh start symbol rewrites to the old one with weight one.
    """
    rhs_symbols = {sym for rule in self for sym in rule.body}
    if self.S not in rhs_symbols:
        return self
    fresh = _gen_nt(self.S)
    out = self.spawn(S=fresh)
    out.add(self.R.one, fresh, self.S)  # fresh start → old start
    for rule in self:
        out.add(rule.w, rule.head, *rule.body)
    return out

separate_terminals()

Ensure that each terminal is produced by a preterminal rule.

Returns:

Type Description

A new CFG with terminals only in preterminal rules

Source code in genlm/grammar/cfg.py
def separate_terminals(self):
    """Ensure each terminal is introduced only by a preterminal rule.

    Every terminal occurring in a non-preterminal rule body is replaced by
    a fresh preterminal nonterminal that rewrites to it with weight one;
    each terminal gets at most one such preterminal.

    Returns:
        A new CFG in which terminals appear only in preterminal rules
    """
    one = self.R.one
    out = self.spawn()
    cache = {}  # terminal -> its (shared) preterminal rule

    def rule_for(t):
        # Create the preterminal rule for `t` at most once.
        if t not in cache:
            cache[t] = out.add(one, _gen_nt(), t)
        return cache[t]

    for r in self:
        if len(r.body) == 1 and self.is_terminal(r.body[0]):
            # Already a preterminal rule; keep as-is.
            out.add(r.w, r.head, *r.body)
        else:
            body = [
                rule_for(sym).head if self.is_terminal(sym) else sym
                for sym in r.body
            ]
            out.add(r.w, r.head, *body)

    return out

spawn(*, R=None, S=None, V=None)

Create an empty grammar with the same R, S, and V.

Parameters:

Name Type Description Default
R

Optional new semiring

None
S

Optional new start symbol

None
V

Optional new vocabulary

None

Returns:

Type Description

A new empty CFG with specified parameters

Source code in genlm/grammar/cfg.py
def spawn(self, *, R=None, S=None, V=None):
    """Create an empty grammar sharing this one's semiring, start, and vocabulary.

    Args:
        R: Optional replacement semiring
        S: Optional replacement start symbol
        V: Optional replacement vocabulary

    Returns:
        A new empty CFG of the same class
    """
    if R is None:
        R = self.R
    if S is None:
        S = self.S
    if V is None:
        V = set(self.V)  # copy, so the spawned grammar can grow its own vocabulary
    return self.__class__(R=R, S=S, V=V)

to_bytes()

Convert terminal symbols from strings to bytes representation.

This method creates a new grammar where all terminal string symbols are converted to their UTF-8 byte representation. Non-terminal symbols are preserved as-is.

Returns:

Name Type Description
CFG

A new grammar with byte terminal symbols

Raises:

Type Description
ValueError

If a terminal symbol is not a string

Source code in genlm/grammar/cfg.py
def to_bytes(self):
    """Convert terminal symbols from strings to their UTF-8 byte values.

    Each string terminal is expanded into its individual byte values
    (ints 0-255); nonterminals are preserved as-is.  The new vocabulary is
    the set of byte values that actually occur.

    Returns:
        CFG: A new grammar whose terminals are byte values

    Raises:
        ValueError: If a terminal symbol is not a string
    """
    out = self.spawn(S=self.S, R=self.R, V=set())

    for rule in self:
        body = []
        for sym in rule.body:
            if not self.is_terminal(sym):
                body.append(sym)
                continue
            if not isinstance(sym, str):
                raise ValueError(f"unsupported terminal type: {type(sym)}")
            encoded = sym.encode("utf-8")  # iterating bytes yields ints
            out.V.update(encoded)
            body.extend(encoded)
        out.add(rule.w, rule.head, *body)

    return out

treesum(**kwargs)

Total weight of the start symbol.

Returns:

Type Description

The total weight of all derivations from the start symbol

Source code in genlm/grammar/cfg.py
def treesum(self, **kwargs):
    """Total weight of all derivations rooted at the start symbol.

    Args:
        **kwargs: Forwarded to `agenda`

    Returns:
        The treesum of `self.S`
    """
    return self.agenda(**kwargs)[self.S]

trim(bottomup_only=False)

Return an equivalent grammar with no dead or useless nonterminals or rules.

Parameters:

Name Type Description Default
bottomup_only

If True, only remove non-generating nonterminals

False

Returns:

Type Description

A new trimmed CFG

Source code in genlm/grammar/cfg.py
def trim(self, bottomup_only=False):
    """
    Return an equivalent grammar with no dead or useless nonterminals or rules.

    Args:
        bottomup_only: If True, only remove non-generating nonterminals

    Returns:
        A new trimmed CFG
    """
    # Memoized per direction (bottom-up only vs. full trim).
    if self._trim_cache[bottomup_only] is not None:
        return self._trim_cache[bottomup_only]

    # C will be the set of generating symbols; terminals and the heads of
    # nullary rules seed the fixpoint.
    C = set(self.V)
    C.update(e.head for e in self.rules if len(e.body) == 0)

    # Index rules by head (incoming) and by body symbol (outgoing) for the
    # two propagation passes below.
    incoming = defaultdict(list)
    outgoing = defaultdict(list)
    for e in self:
        incoming[e.head].append(e)
        for b in e.body:
            outgoing[b].append(e)

    # Bottom-up pass: a head becomes generating once its whole body is.
    agenda = set(C)
    while agenda:
        x = agenda.pop()
        for e in outgoing[x]:
            if all((b in C) for b in e.body):
                if e.head not in C:
                    C.add(e.head)
                    agenda.add(e.head)

    if bottomup_only:
        val = self._trim(C)
        self._trim_cache[bottomup_only] = val
        # A trimmed grammar is its own trim; cache on the result as well.
        val._trim_cache[bottomup_only] = val
        return val

    # Top-down pass: T = generating symbols reachable from the start symbol.
    T = {self.S}
    agenda.update(T)
    while agenda:
        x = agenda.pop()
        for e in incoming[x]:
            # assert e.head in T
            for b in e.body:
                if b not in T and b in C:
                    T.add(b)
                    agenda.add(b)

    val = self._trim(T)
    self._trim_cache[bottomup_only] = val
    val._trim_cache[bottomup_only] = val
    return val

truncate_length(max_length)

Transform this grammar so that it only generates strings with length ≤ max_length.

Source code in genlm/grammar/cfg.py
def truncate_length(self, max_length):
    """Transform this grammar so that it only generates strings with length ≤ `max_length`."""
    from genlm.grammar import WFSA

    # Build a chain automaton accepting all strings of length 0..max_length
    # with weight one, then intersect it with this grammar.
    one = self.R.one
    fsa = WFSA(self.R)
    fsa.add_I(0, one)
    fsa.add_F(0, one)
    for state in range(max_length):
        for sym in self.V:
            fsa.add_arc(state, sym, state + 1, one)
        fsa.add_F(state + 1, one)
    return self @ fsa

unarycycleremove(trim=True)

Return an equivalent grammar with no unary cycles.

Parameters:

Name Type Description Default
trim

If True, trim the resulting grammar

True

Returns:

Type Description

A new CFG without unary cycles

Source code in genlm/grammar/cfg.py
def unarycycleremove(self, trim=True):
    """
    Return an equivalent grammar with no unary cycles.

    Args:
        trim: If True, trim the resulting grammar

    Returns:
        A new CFG without unary cycles
    """

    def bot(x):
        # Symbols in cyclic SCCs get a "bottom" copy that carries the
        # non-unary rules; acyclic symbols are used directly.
        return x if x in acyclic else (x, "bot")

    G = self._unary_graph()

    new = self.spawn(S=self.S)

    bucket = G.buckets  # symbol -> SCC id in the unary-rule graph

    # Singleton SCCs without a self-loop are acyclic and need no rewriting.
    acyclic = set()
    for nodes, _ in G.Blocks:
        if len(nodes) == 1:
            [X] = nodes
            if G[X, X] == self.R.zero:
                acyclic.add(X)

    # run Lehmann's on each cyclical SCC
    for nodes, W in G.Blocks:
        if len(nodes) == 1:
            [X] = nodes
            if X in acyclic:
                continue

        # W appears to be the closed within-SCC unary weight matrix; each
        # closed pair is routed through the bottom copy of its target.
        for X1, X2 in W:
            new.add(W[X1, X2], X1, bot(X2))

    for r in self:
        # Unary rules inside a single SCC are dropped here — their weight is
        # already accounted for by the closure W above.
        if len(r.body) == 1 and bucket.get(r.body[0]) == bucket[r.head]:
            continue
        new.add(r.w, bot(r.head), *r.body)

    # TODO: figure out how to ensure that the new grammar is trimmed by
    # construction (assuming the input grammar was trim).
    if trim:
        new = new.trim()

    return new

unaryremove()

Return an equivalent grammar with no unary rules.

Returns:

Type Description

A new CFG without unary rules

Source code in genlm/grammar/cfg.py
def unaryremove(self):
    """Return an equivalent grammar with no unary (nonterminal-to-nonterminal) rules.

    Uses the closure of the unary-rule graph to fold the total weight of
    all unary chains into the remaining non-unary rules.

    Returns:
        A new CFG without unary rules
    """
    # Closure weight W[Y, X]: total weight of unary chains Y ⇒* X.
    # (An alternative reference implementation is closure_reference().)
    W = self._unary_graph().closure_scc_based()

    out = self.spawn()
    for r in self:
        if len(r.body) == 1 and self.is_nonterminal(r.body[0]):
            continue  # unary rules are absorbed into the closure weights
        for Y in self.N:
            out.add(W[Y, r.head] * r.w, Y, *r.body)

    return out

unfold(i, k)

Apply the unfolding transformation to rule i and subgoal k.

Parameters:

Name Type Description Default
i

Index of rule to unfold

required
k

Index of subgoal in rule body

required

Returns:

Type Description

A new CFG with the rule unfolded

Source code in genlm/grammar/cfg.py
def unfold(self, i, k):
    """Apply the unfolding transformation to rule `i` at body position `k`.

    Rule i is removed and replaced by one rule per production of its k-th
    body symbol, with that production's body spliced in at position k and
    the weights multiplied.

    Args:
        i: Index of the rule to unfold
        k: Index of the (nonterminal) subgoal in the rule's body

    Returns:
        A new CFG with the rule unfolded
    """
    assert isinstance(i, int) and isinstance(k, int)
    target = self.rules[i]
    assert self.is_nonterminal(target.body[k])

    out = self.spawn()
    # Copy every rule except the one being unfolded.
    for j, r in enumerate(self):
        if j == i:
            continue
        out.add(r.w, r.head, *r.body)

    # Splice each production of the subgoal into position k.
    for r in self.rhs[target.body[k]]:
        out.add(
            target.w * r.w,
            target.head,
            *target.body[:k],
            *r.body,
            *target.body[k + 1 :],
        )

    return out

Chart

Bases: dict

A weighted chart data structure that extends dict with semiring operations.

The Chart class provides methods for semiring operations like addition and multiplication, as well as utilities for filtering, comparing, and manipulating weighted values.

Attributes:

Name Type Description
semiring

The semiring that defines the weight operations

Source code in genlm/grammar/chart.py
class Chart(dict):
    """A weighted chart data structure that extends dict with semiring operations.

    The Chart class provides methods for semiring operations like addition and multiplication,
    as well as utilities for filtering, comparing, and manipulating weighted values.

    Attributes:
        semiring: The semiring that defines the weight operations
    """

    def __init__(self, semiring, vals=()):
        """Initialize a Chart.

        Args:
            semiring: The semiring for weight operations
            vals: Optional initial values for the chart
        """
        self.semiring = semiring
        super().__init__(vals)

    def __missing__(self, k):
        """Return zero weight for missing keys."""
        # The zero is returned but NOT stored, so reads never grow the dict.
        return self.semiring.zero

    def spawn(self):
        """Create a new empty Chart with the same semiring."""
        return Chart(self.semiring)

    def __add__(self, other):
        """Add two charts element-wise.

        Args:
            other: Another Chart to add to this one

        Returns:
            A new Chart containing the element-wise sum
        """
        new = self.spawn()
        for k, v in self.items():
            new[k] += v
        for k, v in other.items():
            new[k] += v
        return new

    def __mul__(self, other):
        """Multiply two charts element-wise.

        Args:
            other: Another Chart to multiply with this one

        Returns:
            A new Chart containing the element-wise product
        """
        new = self.spawn()
        # Only keys present in self can be nonzero in the product; zero
        # products are skipped so the result stays sparse.
        for k in self:
            v = self[k] * other[k]
            if v == self.semiring.zero:
                continue
            new[k] += v
        return new

    def product(self, ks):
        """Compute the product of values for the given keys.

        Args:
            ks: Sequence of keys to multiply values for

        Returns:
            The product of values for the given keys
        """
        v = self.semiring.one
        for k in ks:
            v *= self[k]
        return v

    def copy(self):
        """Create a shallow copy of this Chart."""
        return Chart(self.semiring, self)

    def trim(self):
        """Return a new Chart with zero-weight entries removed."""
        return Chart(
            self.semiring, {k: v for k, v in self.items() if v != self.semiring.zero}
        )

    def metric(self, other):
        """Compute the maximum distance between this Chart and another.

        Args:
            other: Another Chart to compare against

        Returns:
            The maximum semiring metric between corresponding values
        """
        assert isinstance(other, Chart)
        err = 0
        for x in self.keys() | other.keys():
            err = max(err, self.semiring.metric(self[x], other[x]))
        return err

    def _repr_html_(self):
        """Return HTML representation for Jupyter notebooks."""
        return (
            '<div style="font-family: Monospace;">'
            + format_table(self.trim().items(), headings=["key", "value"])
            + "</div>"
        )

    def __repr__(self):
        """Return string representation, excluding zero weights."""
        return repr({k: v for k, v in self.items() if v != self.semiring.zero})

    def __str__(self, style_value=lambda k, v: str(v)):
        """Return formatted string representation.

        Args:
            style_value: Optional function to format values

        Returns:
            Formatted string showing non-zero entries
        """

        def key(k):
            # Sort by decreasing "magnitude" (distance from the zero weight).
            return -self.semiring.metric(self[k], self.semiring.zero)

        return (
            "Chart {\n"
            + "\n".join(
                f"  {k!r}: {style_value(k, self[k])},"
                for k in sorted(self, key=key)
                if self[k] != self.semiring.zero
            )
            + "\n}"
        )

    def assert_equal(self, want, *, domain=None, tol=1e-5, verbose=False, throw=True):
        """Assert that this Chart equals another within tolerance.

        Args:
            want: The expected Chart or dict of values
            domain: Optional set of keys to check
            tol: Tolerance for floating point comparisons
            verbose: Whether to print detailed comparison
            throw: Whether to raise AssertionError on mismatch
        """
        if not isinstance(want, Chart):
            want = self.semiring.chart(want)
        if domain is None:
            domain = self.keys() | want.keys()
        # With neither verbose nor throw the call would silently do nothing.
        assert verbose or throw
        errors = []
        for x in domain:
            if self.semiring.metric(self[x], want[x]) <= tol:
                if verbose:
                    print(colors.mark(True), x, self[x])
            else:
                if verbose:
                    print(colors.mark(False), x, self[x], want[x])
                errors.append(x)
        if throw:
            for x in errors:
                raise AssertionError(f"{x}: {self[x]} {want[x]}")

    def argmax(self):
        """Return the key with maximum value."""
        return max(self, key=self.__getitem__)

    def argmin(self):
        """Return the key with minimum value."""
        return min(self, key=self.__getitem__)

    def top(self, k):
        """Return a new Chart with the k largest values.

        Args:
            k: Number of top values to keep

        Returns:
            A new Chart containing only the k largest values
        """
        # NOTE: the comprehension variable `k` shadows the parameter `k`;
        # the slice [:k] is evaluated first (with the parameter), so this is
        # correct — just easy to misread.
        return Chart(
            self.semiring,
            {k: self[k] for k in sorted(self, key=self.__getitem__, reverse=True)[:k]},
        )

    def max(self):
        """Return the maximum value in the Chart."""
        return max(self.values())

    def min(self):
        """Return the minimum value in the Chart."""
        return min(self.values())

    def sum(self):
        """Return the sum of all values in the Chart."""
        # Uses builtin sum (starts at 0) — presumably numeric-valued
        # semirings only.
        return sum(self.values())

    def sort(self, **kwargs):
        """Return a new Chart with entries sorted by key.

        Args:
            **kwargs: Arguments passed to sorted()

        Returns:
            A new Chart with sorted entries
        """
        return self.semiring.chart((k, self[k]) for k in sorted(self, **kwargs))

    def sort_descending(self):
        """Return a new Chart with entries sorted by decreasing value."""
        # Assumes values support unary minus (numeric semirings).
        return self.semiring.chart(
            (k, self[k]) for k in sorted(self, key=lambda k: -self[k])
        )

    def normalize(self):
        """Return a new Chart with values normalized to sum to 1."""
        Z = self.sum()
        # A zero total (e.g. empty chart) is returned unchanged to avoid
        # division by zero.
        if Z == 0:
            return self
        return self.semiring.chart((k, v / Z) for k, v in self.items())

    def filter(self, f):
        """Return a new Chart keeping only entries where f(key) is True.

        Args:
            f: Predicate function that takes a key and returns bool

        Returns:
            A new Chart containing only entries where f(key) is True
        """
        return self.semiring.chart((k, v) for k, v in self.items() if f(k))

    def project(self, f):
        """Apply a function to keys, summing weights when transformed keys overlap.

        Args:
            f: Function to transform keys

        Returns:
            A new Chart with transformed keys and summed weights
        """
        out = self.semiring.chart()
        for k, v in self.items():
            out[f(k)] += v
        return out

    # TODO: the more general version of this method is join
    def compare(self, other, *, domain=None):
        """Compare this Chart to another using pandas DataFrame.

        Args:
            other: Another Chart or dict to compare against
            domain: Optional set of keys to compare

        Returns:
            pandas DataFrame showing key-by-key comparison
        """
        import pandas as pd

        if not isinstance(other, Chart):
            other = self.semiring.chart(other)
        if domain is None:
            domain = self.keys() | other.keys()
        rows = []
        for x in domain:
            m = self.semiring.metric(self[x], other[x])
            rows.append(dict(key=x, self=self[x], other=other[x], metric=m))
        return pd.DataFrame(rows)

__add__(other)

Add two charts element-wise.

Parameters:

Name Type Description Default
other

Another Chart to add to this one

required

Returns:

Type Description

A new Chart containing the element-wise sum

Source code in genlm/grammar/chart.py
def __add__(self, other):
    """Element-wise sum of two charts.

    Args:
        other: Another Chart to add to this one

    Returns:
        A new Chart containing the element-wise sum
    """
    total = self.spawn()
    for chart in (self, other):
        for key, val in chart.items():
            total[key] += val
    return total

__init__(semiring, vals=())

Initialize a Chart.

Parameters:

Name Type Description Default
semiring

The semiring for weight operations

required
vals

Optional initial values for the chart

()
Source code in genlm/grammar/chart.py
def __init__(self, semiring, vals=()):
    """Initialize a Chart.

    Args:
        semiring: The semiring supplying zero and the weight operations
        vals: Optional initial (key, value) pairs or mapping
    """
    super().__init__(vals)
    self.semiring = semiring

__missing__(k)

Return zero weight for missing keys.

Source code in genlm/grammar/chart.py
def __missing__(self, key):
    """Missing keys read as the semiring's zero (nothing is stored)."""
    return self.semiring.zero

__mul__(other)

Multiply two charts element-wise.

Parameters:

Name Type Description Default
other

Another Chart to multiply with this one

required

Returns:

Type Description

A new Chart containing the element-wise product

Source code in genlm/grammar/chart.py
def __mul__(self, other):
    """Element-wise product of two charts.

    Args:
        other: Another Chart to multiply with this one

    Returns:
        A new Chart containing the nonzero element-wise products
    """
    zero = self.semiring.zero
    prod = self.spawn()
    # Only keys present in self can contribute; zero products are skipped
    # so the result stays sparse.
    for key in self:
        val = self[key] * other[key]
        if val != zero:
            prod[key] += val
    return prod

__repr__()

Return string representation, excluding zero weights.

Source code in genlm/grammar/chart.py
def __repr__(self):
    """Plain-dict repr of the nonzero entries only."""
    zero = self.semiring.zero
    return repr({key: val for key, val in self.items() if val != zero})

__str__(style_value=lambda k, v: str(v))

Return formatted string representation.

Parameters:

Name Type Description Default
style_value

Optional function to format values

lambda k, v: str(v)

Returns:

Type Description

Formatted string showing non-zero entries

Source code in genlm/grammar/chart.py
def __str__(self, style_value=lambda k, v: str(v)):
    """Pretty multi-line rendering of the nonzero entries.

    Entries are ordered by decreasing magnitude (distance from zero).

    Args:
        style_value: Optional function to format values

    Returns:
        Formatted string showing non-zero entries
    """
    zero = self.semiring.zero
    metric = self.semiring.metric

    def magnitude(key):
        # Negated so that sorted() puts the largest magnitudes first.
        return -metric(self[key], zero)

    lines = [
        f"  {key!r}: {style_value(key, self[key])},"
        for key in sorted(self, key=magnitude)
        if self[key] != zero
    ]
    return "Chart {\n" + "\n".join(lines) + "\n}"

argmax()

Return the key with maximum value.

Source code in genlm/grammar/chart.py
def argmax(self):
    """Return the key with maximum value."""
    return max(self, key=self.__getitem__)

argmin()

Return the key with minimum value.

Source code in genlm/grammar/chart.py
def argmin(self):
    """Return the key with minimum value."""
    return min(self, key=self.__getitem__)

assert_equal(want, *, domain=None, tol=1e-05, verbose=False, throw=True)

Assert that this Chart equals another within tolerance.

Parameters:

Name Type Description Default
want

The expected Chart or dict of values

required
domain

Optional set of keys to check

None
tol

Tolerance for floating point comparisons

1e-05
verbose

Whether to print detailed comparison

False
throw

Whether to raise AssertionError on mismatch

True
Source code in genlm/grammar/chart.py
def assert_equal(self, want, *, domain=None, tol=1e-5, verbose=False, throw=True):
    """Assert that this Chart equals another within tolerance.

    Args:
        want: The expected Chart or dict of values
        domain: Optional set of keys to check
        tol: Tolerance for floating point comparisons
        verbose: Whether to print detailed comparison
        throw: Whether to raise AssertionError on mismatch
    """
    if not isinstance(want, Chart):
        want = self.semiring.chart(want)
    if domain is None:
        domain = self.keys() | want.keys()
    # With both verbose and throw disabled the call would silently do nothing.
    assert verbose or throw
    mismatched = []
    for key in domain:
        within_tol = self.semiring.metric(self[key], want[key]) <= tol
        if within_tol:
            if verbose:
                print(colors.mark(True), key, self[key])
        else:
            if verbose:
                print(colors.mark(False), key, self[key], want[key])
            mismatched.append(key)
    if throw:
        for key in mismatched:
            raise AssertionError(f"{key}: {self[key]} {want[key]}")

compare(other, *, domain=None)

Compare this Chart to another using pandas DataFrame.

Parameters:

Name Type Description Default
other

Another Chart or dict to compare against

required
domain

Optional set of keys to compare

None

Returns:

Type Description

pandas DataFrame showing key-by-key comparison

Source code in genlm/grammar/chart.py
def compare(self, other, *, domain=None):
    """Compare this Chart to another using pandas DataFrame.

    Args:
        other: Another Chart or dict to compare against
        domain: Optional set of keys to compare

    Returns:
        pandas DataFrame showing key-by-key comparison
    """
    import pandas as pd

    if not isinstance(other, Chart):
        other = self.semiring.chart(other)
    if domain is None:
        domain = self.keys() | other.keys()
    records = [
        dict(
            key=key,
            self=self[key],
            other=other[key],
            metric=self.semiring.metric(self[key], other[key]),
        )
        for key in domain
    ]
    return pd.DataFrame(records)

copy()

Create a shallow copy of this Chart.

Source code in genlm/grammar/chart.py
def copy(self):
    """Create a shallow copy of this Chart."""
    return Chart(self.semiring, self)

filter(f)

Return a new Chart keeping only entries where f(key) is True.

Parameters:

Name Type Description Default
f

Predicate function that takes a key and returns bool

required

Returns:

Type Description

A new Chart containing only entries where f(key) is True

Source code in genlm/grammar/chart.py
def filter(self, f):
    """Return a new Chart keeping only entries where f(key) is True.

    Args:
        f: Predicate function that takes a key and returns bool

    Returns:
        A new Chart containing only entries where f(key) is True
    """
    kept = ((key, val) for key, val in self.items() if f(key))
    return self.semiring.chart(kept)

max()

Return the maximum value in the Chart.

Source code in genlm/grammar/chart.py
def max(self):
    """Return the maximum value in the Chart."""
    return max(self.values())

metric(other)

Compute the maximum distance between this Chart and another.

Parameters:

Name Type Description Default
other

Another Chart to compare against

required

Returns:

Type Description

The maximum semiring metric between corresponding values

Source code in genlm/grammar/chart.py
def metric(self, other):
    """Compute the maximum distance between this Chart and another.

    Args:
        other: Another Chart to compare against

    Returns:
        The maximum semiring metric between corresponding values
    """
    assert isinstance(other, Chart)
    worst = 0
    # Iterate over the union of keys; missing keys default to zero weight.
    for key in self.keys() | other.keys():
        worst = max(worst, self.semiring.metric(self[key], other[key]))
    return worst

min()

Return the minimum value in the Chart.

Source code in genlm/grammar/chart.py
def min(self):
    """Return the minimum value in the Chart."""
    return min(self.values())

normalize()

Return a new Chart with values normalized to sum to 1.

Source code in genlm/grammar/chart.py
def normalize(self):
    """Return a new Chart with values normalized to sum to 1."""
    Z = self.sum()
    if Z == 0:
        return self
    return self.semiring.chart((k, v / Z) for k, v in self.items())

product(ks)

Compute the product of values for the given keys.

Parameters:

Name Type Description Default
ks

Sequence of keys to multiply values for

required

Returns:

Type Description

The product of values for the given keys

Source code in genlm/grammar/chart.py
def product(self, ks):
    """Compute the product of values for the given keys.

    Args:
        ks: Sequence of keys to multiply values for

    Returns:
        The product of values for the given keys; the semiring's
        multiplicative identity when ``ks`` is empty.
    """
    acc = self.semiring.one
    for key in ks:
        acc *= self[key]
    return acc

project(f)

Apply a function to keys, summing weights when transformed keys overlap.

Parameters:

Name Type Description Default
f

Function to transform keys

required

Returns:

Type Description

A new Chart with transformed keys and summed weights

Source code in genlm/grammar/chart.py
def project(self, f):
    """Apply a function to keys, summing weights when transformed keys overlap.

    Args:
        f: Function to transform keys

    Returns:
        A new Chart with transformed keys and summed weights
    """
    projected = self.semiring.chart()
    for key, val in self.items():
        # Accumulate: keys that collapse to the same image add their weights.
        projected[f(key)] += val
    return projected

sort(**kwargs)

Return a new Chart with entries sorted by key.

Parameters:

Name Type Description Default
**kwargs

Arguments passed to sorted()

{}

Returns:

Type Description

A new Chart with sorted entries

Source code in genlm/grammar/chart.py
def sort(self, **kwargs):
    """Return a new Chart with entries sorted by key.

    Args:
        **kwargs: Arguments passed to sorted()

    Returns:
        A new Chart with sorted entries
    """
    return self.semiring.chart((k, self[k]) for k in sorted(self, **kwargs))

sort_descending()

Return a new Chart with entries sorted by decreasing value.

Source code in genlm/grammar/chart.py
def sort_descending(self):
    """Return a new Chart with entries sorted by decreasing value."""
    return self.semiring.chart(
        (k, self[k]) for k in sorted(self, key=lambda k: -self[k])
    )

spawn()

Create a new empty Chart with the same semiring.

Source code in genlm/grammar/chart.py
def spawn(self):
    """Create a new empty Chart with the same semiring."""
    return Chart(self.semiring)

sum()

Return the sum of all values in the Chart.

Source code in genlm/grammar/chart.py
def sum(self):
    """Return the sum of all values in the Chart."""
    return sum(self.values())

top(k)

Return a new Chart with the k largest values.

Parameters:

Name Type Description Default
k

Number of top values to keep

required

Returns:

Type Description

A new Chart containing only the k largest values

Source code in genlm/grammar/chart.py
def top(self, k):
    """Return a new Chart with the k largest values.

    Args:
        k: Number of top values to keep

    Returns:
        A new Chart containing only the k largest values
    """
    return Chart(
        self.semiring,
        {k: self[k] for k in sorted(self, key=self.__getitem__, reverse=True)[:k]},
    )

trim()

Return a new Chart with zero-weight entries removed.

Source code in genlm/grammar/chart.py
def trim(self):
    """Return a new Chart with zero-weight entries removed."""
    return Chart(
        self.semiring, {k: v for k, v in self.items() if v != self.semiring.zero}
    )

Earley

Implements a semiring-weighted version of Earley's algorithm that runs in $\mathcal{O}(N^3|G|)$ time. Note that nullary rules and unary chain cycles will have been removed, altering the set of derivation trees.

Source code in genlm/grammar/parse/earley.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
class Earley:
    """
    Implements a semiring-weighted version of Earley's algorithm that runs in $\mathcal{O}(N^3|G|)$ time.
    Note that nullary rules and unary chain cycles will have been removed, altering the
    set of derivation trees.
    """

    __slots__ = (
        "cfg",
        "order",
        "_chart",
        "V",
        "eos",
        "_initial_column",
        "R_outgoing",
        "rhs",
        "ORDER_MAX",
        "intern_Ys",
        "unit_Ys",
        "first_Ys",
        "rest_Ys",
    )

    def __init__(self, cfg):
        # Normalize the grammar: drop nullary rules, break unary cycles, and
        # renumber symbols (preconditions stated in the class docstring).
        cfg = cfg.nullaryremove(binarize=True).unarycycleremove().renumber()
        self.cfg = cfg

        # cache of chart columns
        self._chart = {}

        # Topological ordering on the grammar symbols so that we process unary
        # rules in a topological order.
        self.order = cfg._unary_graph_transpose().buckets

        self.ORDER_MAX = 1 + max(self.order.values())

        # left-corner graph
        R_outgoing = defaultdict(set)
        for r in cfg:
            if len(r.body) == 0:
                continue
            A = r.head
            B = r.body[0]
            if cfg.is_terminal(B):
                continue
            R_outgoing[A].add(B)
        self.R_outgoing = R_outgoing

        # Integerize rule right-hand side states
        intern_Ys = Integerizer()
        # code 0 is reserved for the empty suffix () — `_update` relies on this.
        assert intern_Ys(()) == 0

        # Every suffix of every rule body becomes an integer code.
        for r in self.cfg:
            for p in range(len(r.body) + 1):
                intern_Ys.add(r.body[p:])

        self.intern_Ys = intern_Ys

        self.rhs = {}
        for X in self.cfg.N:
            self.rhs[X] = []
            for r in self.cfg.rhs[X]:
                if r.body == ():
                    continue
                self.rhs[X].append((r.w, intern_Ys(r.body)))

        # Dense lookup tables over suffix codes:
        #   first_Ys[code] — first symbol of the suffix
        #   rest_Ys[code]  — code of the suffix minus its first symbol
        #   unit_Ys[code]  — 1 iff the suffix has exactly one symbol left
        self.first_Ys = np.zeros(len(intern_Ys), dtype=object)
        self.rest_Ys = np.zeros(len(intern_Ys), dtype=int)
        self.unit_Ys = np.zeros(len(intern_Ys), dtype=int)

        for Ys, code in list(self.intern_Ys.items()):
            self.unit_Ys[code] = len(Ys) == 1
            if len(Ys) > 0:
                self.first_Ys[code] = Ys[0]
                self.rest_Ys[code] = intern_Ys(Ys[1:])

        # self.generate_rust_test_case()

        # Precompute column 0 (predictions from the start symbol); it is
        # shared by every parse.
        col = Column(0)
        self.PREDICT(col)
        self._initial_column = col

    def clear_cache(self):
        """Drop all cached chart columns."""
        self._chart.clear()

    def __call__(self, x):
        """Return the total weight of input sequence `x` under the grammar."""
        N = len(x)

        # return if empty string
        if N == 0:
            return sum(r.w for r in self.cfg.rhs[self.cfg.S] if r.body == ())

        # initialize bookkeeping structures
        self._chart[()] = [self._initial_column]

        cols = self.chart(x)

        # Weight of the completed item (0, S) in the last column, i.e., a
        # finished start-symbol parse covering the whole input.
        value = cols[N].c_chart.get((0, self.cfg.S))
        return value if value is not None else self.cfg.R.zero

    def chart(self, x):
        """Return (computing and caching if needed) the chart columns for `x`."""
        x = tuple(x)
        c = self._chart.get(x)
        if c is None:
            self._chart[x] = c = self._compute_chart(x)
        return c

    def _compute_chart(self, x):
        # Recursive on prefixes so that shared prefixes reuse cached columns.
        if len(x) == 0:
            return [self._initial_column]
        else:
            chart = self.chart(x[:-1])
            last_chart = self.next_column(chart, x[-1])
            return chart + [
                last_chart
            ]  # TODO: avoid list addition here as it is not constant time!

    def next_column(self, prev_cols, token):
        """Extend the chart by one column after consuming `token`."""
        prev_col = prev_cols[-1]
        next_col = Column(prev_cols[-1].k + 1)
        next_col_c_chart = next_col.c_chart
        prev_col_i_chart = prev_col.i_chart

        rest_Ys = self.rest_Ys
        _update = self._update

        Q = LocatorMaxHeap()

        # SCAN: phrase(I, X/Ys, K) += phrase(I, X/[Y|Ys], J) * word(J, Y, K)
        for item in prev_col.waiting_for[token]:
            (I, X, Ys) = item
            _update(next_col, Q, I, X, rest_Ys[Ys], prev_col_i_chart[item])

        # ATTACH: phrase(I, X/Ys, K) += phrase(I, X/[Y|Ys], J) * phrase(J, Y/[], K)
        while Q:
            jy = Q.pop()[0]
            (J, Y) = jy

            col_J = prev_cols[J]
            col_J_i_chart = col_J.i_chart
            y = next_col_c_chart[jy]
            for customer in col_J.waiting_for[Y]:
                (I, X, Ys) = customer
                _update(next_col, Q, I, X, rest_Ys[Ys], col_J_i_chart[customer] * y)

        self.PREDICT(next_col)

        return next_col

    def PREDICT(self, col):
        # PREDICT: phrase(K, X/Ys, K) += rule(X -> Ys) with some filtering heuristics
        k = col.k

        # Filtering heuristic: Don't create the predicted item (K, X, [...], K)
        # unless there exists an item that wants the X item that it may
        # eventually provide.  In other words, for predicting this item to be
        # useful there must be an item of the form (I, X', [X, ...], K) in this
        # column for which lc(X', X) is true.
        if col.k == 0:
            agenda = [self.cfg.S]
        else:
            agenda = list(col.waiting_for)

        outgoing = self.R_outgoing

        reachable = set(agenda)

        # Transitive closure over the left-corner graph.
        while agenda:
            X = agenda.pop()
            for Y in outgoing[X]:
                if Y not in reachable:
                    reachable.add(Y)
                    agenda.append(Y)

        rhs = self.rhs
        _update = self._update
        for X in reachable:
            for w, Ys in rhs.get(X, ()):
                _update(col, None, k, X, Ys, w)

    def _update(self, col, Q, I, X, Ys, value):
        """Accumulate `value` into the item (I, X, Ys) of column `col`."""
        K = col.k
        if Ys == 0:
            # Items of the form phrase(I, X/[], K)
            item = (I, X)
            was = col.c_chart.get(item)
            if was is None:
                # Priority: max-heap on the negated score, so narrower spans
                # (smaller K - I) pop first, ties broken by unary topological
                # order.
                Q[item] = -((K - I) * self.ORDER_MAX + self.order[X])
                col.c_chart[item] = value
            else:
                col.c_chart[item] = was + value

        else:
            # Items of the form phrase(I, X/[Y|Ys], K)
            item = (I, X, Ys)
            was = col.i_chart.get(item)
            if was is None:
                # First time we see this item: register it as waiting for its
                # next needed symbol.
                col.waiting_for[self.first_Ys[Ys]].append(item)
                col.i_chart[item] = value
            else:
                col.i_chart[item] = was + value

    # We have derived the `next_token_weights` algorithm by backpropagation on
    # the program with respect to the item `phrase(0, s, K)`.
    #
    # ATTACH: phrase(I, X/Ys, K) += phrase(I, X/[Y|Ys], J) * phrase(J, Y/[], K)
    #
    # Directly applying the gradient transformation, we get
    #
    # ∇phrase(0, s/[], K) += 1
    # ∇phrase(J, Y/[], K) += phrase(I, X/[Y|Ys], J) * ∇phrase(I, X/Ys, K)
    #
    # Some quick analysis reveals that the `Ys` list must always be [], and
    # that K is always equal to the final column.  We specialize the program
    # below:
    #
    # ∇phrase(0, s/[], K) += 1
    # ∇phrase(J, Y/[], K) += phrase(I, X/[Y], J) * ∇phrase(I, X/[], K)
    #
    # We can abbreviate the names:
    #
    # q(0, s) += 1
    # q(J, Y) += phrase(I, X/[Y], J) * q(I, X)
    #
    # These items satisfy (I > J) and (X > Y) where the latter is the
    # nonterminal ordering.  Thus, we can efficiently evaluate these equations
    # by backward chaining.
    #
    # The final output is the vector
    #
    # p(W) += q(I, X) * phrase(I, X/[W], J)  where len(J) * terminal(W).
    #
    def next_token_weights(self, cols):
        """Return a chart mapping each candidate next terminal to its prefix weight."""
        is_terminal = self.cfg.is_terminal
        zero = self.cfg.R.zero

        q = {}
        q[0, self.cfg.S] = self.cfg.R.one

        col = cols[-1]
        col_waiting_for = col.waiting_for
        col_i_chart = col.i_chart

        # SCAN: phrase(I, X/Ys, K) += phrase(I, X/[Y|Ys], J) * word(J, Y, K)
        p = self.cfg.R.chart()

        for Y in col_waiting_for:
            if is_terminal(Y):
                total = zero
                for I, X, Ys in col_waiting_for[Y]:
                    if self.unit_Ys[Ys]:
                        node = (I, X)
                        value = self._helper(node, cols, q)
                        total += col_i_chart[I, X, Ys] * value
                p[Y] = total

        return p

    def _helper(self, top, cols, q):
        """Evaluate q[top] by iterative depth-first search, memoizing into `q`."""
        value = q.get(top)
        if value is not None:
            return value

        zero = self.cfg.R.zero
        stack = [Node(top, None, zero)]

        while stack:
            node = stack[-1]  # 👀

            # place neighbors above the node on the stack
            (J, Y) = node.node

            t = node.cursor

            if node.edges is None:
                # Lazily enumerate this node's dependencies: items waiting for
                # Y with exactly one symbol remaining.
                node.edges = [x for x in cols[J].waiting_for[Y] if self.unit_Ys[x[2]]]

            # cursor is at the end, all neighbors are done
            elif t == len(node.edges):
                # clear the node from the stack
                stack.pop()
                # promote the incomplete value node.value to a complete value (q)
                q[node.node] = node.value

            else:
                (I, X, _) = arc = node.edges[t]
                neighbor = (I, X)
                neighbor_value = q.get(neighbor)
                if neighbor_value is None:
                    stack.append(Node(neighbor, None, zero))
                else:
                    # neighbor value is ready, advance the cursor, add the
                    # neighbors contribution to the nodes value
                    node.cursor += 1
                    node.value += cols[J].i_chart[arc] * neighbor_value

        return q[top]

    def generate_rust_test_case(self):
        # generates a test case in Rust code by exporting the parser state variables
        # Copy-paste the printout to `mod tests { ... }` in lib.rs to debug.

        print(
            """
    #[test]
    fn test_earley() {{

        let rhs: HashMap<u32, Vec<RHS>> = [
            {}
        ].iter().cloned().collect();
        """.format(
                ", ".join(
                    f"({x}, "
                    + "vec![{}])".format(", ".join(f"({float(u)}, {v})" for u, v in y))
                    for x, y in self.rhs.items()
                )
            )
        )

        print(
            """
        let order: HashMap<u32, u32> = [
            {}
        ].iter().cloned().collect();
        """.format(", ".join(f"({u}, {v})" for u, v in self.order.items()))
        )

        print(
            """
        let outgoing: HashMap<u32, Vec<u32>> = [
            {}
        ].iter().cloned().collect();
        """.format(
                ", ".join(
                    "({}, vec![{}])".format(i, ", ".join(map(str, s)))
                    for i, s in self.R_outgoing.items()
                )
            )
        )

        print(
            """
        let first_ys = vec![
            {}
        ].iter().cloned().collect();
        """.format(
                ", ".join(
                    f'Terminal(String::from("{y}"))'
                    if isinstance(y, str)
                    else f"Nonterminal({y})"
                    for y in self.first_Ys
                )
            )
        )

        print(
            """
        let rest_ys = vec![
            {}
        ];
        """.format(", ".join(map(str, self.rest_Ys)))
        )

        print(
            """
        let unit_ys = vec![
            {}
        ];
        """.format(", ".join(map(lambda x: str(bool(x)).lower(), self.unit_Ys)))
        )

        print(
            """
        let vocab = [
            {}
        ].iter().cloned().collect();
        """.format(", ".join(f'String::from("{v}")' for v in self.cfg.V))
        )

        print(
            """
        let empty_weight = {};
        let start = {};
        let order_max = {};
        """.format(
                sum(r.w for r in self.cfg.rhs[self.cfg.S] if r.body == ()),
                self.cfg.S,
                self.ORDER_MAX,
            )
        )

        print("""
        let mut earley = Earley::new(
            rhs, start, order, order_max, outgoing, first_ys,
            rest_ys, unit_ys, vocab, empty_weight,
        );
        let chart = earley.p_next(vec![]);
        dbg!(&chart);

    }}
        """)

FST

Bases: WFSA

A weighted finite-state transducer that maps between two alphabets.

A finite-state transducer (FST) extends a weighted finite-state automaton (WFSA) by having two alphabets - an input alphabet A and an output alphabet B. Each transition is labeled with a pair (a,b) where a is from A and b is from B.

The FST defines a weighted relation between strings over A and strings over B.

Source code in genlm/grammar/fst.py
class FST(WFSA):
    """A weighted finite-state transducer that maps between two alphabets.

    A finite-state transducer (FST) extends a weighted finite-state automaton (WFSA)
    by having two alphabets - an input alphabet A and an output alphabet B. Each transition
    is labeled with a pair (a,b) where a is from A and b is from B.

    The FST defines a weighted relation between strings over A and strings over B.
    """

    def __init__(self, R):
        """Initialize an empty FST.

        Args:
            R: The semiring for transition weights
        """
        super().__init__(R=R)

        # alphabets (grown incrementally as arcs are added)
        self.A = set()  # input alphabet
        self.B = set()  # output alphabet

    def add_arc(self, i, ab, j, w):  # pylint: disable=arguments-renamed
        """Add a weighted transition between states.

        Args:
            i: Source state
            ab: Tuple (a,b) of input/output symbols, or EPSILON
            j: Target state
            w: Weight of the transition

        Returns:
            self
        """
        if ab != EPSILON:
            (a, b) = ab
            self.A.add(a)
            self.B.add(b)
        return super().add_arc(i, ab, j, w)

    def set_arc(self, i, ab, j, w):  # pylint: disable=arguments-renamed
        """Set the weight of a transition between states.

        Args:
            i: Source state
            ab: Tuple (a,b) of input/output symbols, or EPSILON
            j: Target state
            w: New weight for the transition

        Returns:
            self
        """
        if ab != EPSILON:
            (a, b) = ab
            self.A.add(a)
            self.B.add(b)
        return super().set_arc(i, ab, j, w)

    def __call__(self, x, y):
        """Compute the weight of mapping input x to output y.

        If x or y is None, returns a weighted language representing the cross section.

        Args:
            x: Input string or None
            y: Output string or None

        Returns:
            Weight of mapping x to y, or a WFSA representing the cross section if x or y is None
        """
        if x is not None and y is not None:
            x = FST.from_string(x, self.R)
            y = FST.from_string(y, self.R)
            return (x @ self @ y).total_weight()

        elif x is not None and y is None:
            x = FST.from_string(x, self.R)
            return (x @ self).project(1)

        elif x is None and y is not None:
            y = FST.from_string(y, self.R)
            return (self @ y).project(0)

        else:
            return self

    @classmethod
    def from_string(cls, xs, R, w=None):
        """Create an FST that accepts only the given string with optional weight.

        Args:
            xs: Input string
            R: Semiring for weights
            w: Optional weight for the string

        Returns:
            An FST accepting only xs with weight w
        """
        return cls.diag(WFSA.from_string(xs=xs, R=R, w=w))

    @staticmethod
    def from_pairs(pairs, R):
        """Create an FST accepting the given input-output string pairs.

        Args:
            pairs: List of (input_string, output_string) tuples
            R: Semiring for weights

        Returns:
            An FST accepting the given string pairs with weight one
        """
        p = FST(R)
        p.add_I(0, R.one)
        p.add_F(1, R.one)
        for i, (xs, ys) in enumerate(pairs):
            p.add_arc(0, EPSILON, (i, 0), R.one)
            # Pad the shorter string with EPSILON so both are consumed in lockstep.
            for j, (x, y) in enumerate(zip_longest(xs, ys, fillvalue=EPSILON)):
                p.add_arc((i, j), (x, y), (i, j + 1), R.one)
            p.add_arc((i, max(len(xs), len(ys))), EPSILON, 1, R.one)
        return p

    def project(self, axis):
        """Project the FST onto one of its components to create a WFSA.

        Args:
            axis: 0 for input projection, 1 for output projection

        Returns:
            A WFSA over the projected alphabet
        """
        assert axis in [0, 1]
        A = WFSA(R=self.R)
        for i, (a, b), j, w in self.arcs():
            if axis == 0:
                A.add_arc(i, a, j, w)
            else:
                A.add_arc(i, b, j, w)
        for i, w in self.I:
            A.add_I(i, w)
        for i, w in self.F:
            A.add_F(i, w)
        return A

    @cached_property
    def T(self):
        """Return the transpose of this FST by swapping input/output labels.

        Returns:
            A new FST with input/output labels swapped
        """
        T = self.spawn()
        for i, (a, b), j, w in self.arcs():
            T.add_arc(i, (b, a), j, w)  # (a,b) -> (b,a)
        for q, w in self.I:
            T.add_I(q, w)
        for q, w in self.F:
            T.add_F(q, w)
        return T

    def prune_to_alphabet(self, A, B):
        """Remove transitions with labels not in the given alphabets.

        Args:
            A: Set of allowed input symbols, or None to allow all
            B: Set of allowed output symbols, or None to allow all

        Returns:
            A new FST with invalid transitions removed
        """
        T = self.spawn()
        for i, (a, b), j, w in self.arcs():
            if (A is None or a in A) and (B is None or b in B):
                T.add_arc(i, (a, b), j, w)
        for q, w in self.I:
            T.add_I(q, w)
        for q, w in self.F:
            T.add_F(q, w)
        # NOTE(review): `trim` is accessed without parentheses — presumably a
        # property on WFSA; verify.
        return T.trim

    def __matmul__(self, other):
        """Compose this FST with another FST or automaton.

        Args:
            other: Another FST, CFG or automaton to compose with

        Returns:
            The composed FST
        """
        if not isinstance(other, FST):
            from genlm.grammar.cfg import CFG

            if isinstance(other, CFG):
                return other @ self.T
            else:
                other = FST.diag(other)

        # minor efficiency trick: it's slightly more efficient to associate the composition as follows
        if len(self.states) < len(other.states):
            return (
                self._augment_epsilon_transitions(0)  # rename epsilons on the right
                ._compose(
                    epsilon_filter_fst(self.R, self.B), coarsen=False
                )  # this FST carefully combines the special epsilons
                ._compose(
                    other._augment_epsilon_transitions(1)
                )  # rename epsilons on the left
            )

        else:
            return self._augment_epsilon_transitions(
                0
            )._compose(  # rename epsilons on the right
                epsilon_filter_fst(
                    self.R, self.B
                )._compose(  # this FST carefully combines the special epsilons
                    other._augment_epsilon_transitions(1), coarsen=False
                )
            )  # rename epsilons on the left

    def _compose(self, other, coarsen=True):
        """Internal composition implementation with optional coarsening.

        Args:
            other: FST to compose with
            coarsen: Whether to apply pruning/coarsening

        Returns:
            The composed FST
        """
        if coarsen and FST.PRUNING is not None:
            keep = FST.PRUNING(self, other)  # pylint: disable=E1102
            result = self._pruned_compose(other, keep, keep.keep_arc)

        else:
            result = self._pruned_compose(
                other, lambda x: True, lambda i, label, j: True
            )

        return result

    # TODO: add assertions for the 'bad' epsilon cases to ensure users aren't using this method incorrectly.
    def _pruned_compose(self, other, keep, keep_arc):
        """Implements pruned on-the-fly composition of FSTs.

        Args:
            other: FST to compose with
            keep: Function that determines which states to keep
            keep_arc: Function that determines which arcs to keep

        Returns:
            The composed FST with pruning applied
        """
        C = FST(R=self.R)

        # index arcs in `other` so that later lookups against them are fast
        tmp = defaultdict(list)
        for i, (a, b), j, w in other.arcs():
            tmp[i, a].append((b, j, w))

        visited = set()
        stack = []

        # add initial states
        for P, w1 in self.I:
            for Q, w2 in other.I:
                PQ = (P, Q)

                if not keep(PQ):
                    continue

                C.add_I(PQ, w1 * w2)
                visited.add(PQ)
                stack.append(PQ)

        # traverse the machine using depth-first search
        while stack:
            P, Q = PQ = stack.pop()

            # (q,p) is simultaneously a final state in the respective machines
            if P in self.stop and Q in other.stop:
                C.add_F(PQ, self.stop[P] * other.stop[Q])
                # Note: final states are not necessarily absorbing -> fall thru

            # Arcs of the composition machine are given by a cross-product-like
            # construction that matches an arc labeled `a:b` with an arc labeled
            # `b:c` in the left and right machines respectively.
            # (Restored: the destination-state variables below were garbled in
            # the previous revision, which was a syntax error.)
            for (a, b), P_next, w1 in self.arcs(P):
                for c, Q_next, w2 in tmp[Q, b]:
                    assert b != EPSILON

                    next_PQ = (P_next, Q_next)

                    if not keep(next_PQ) or not keep_arc(PQ, (a, c), next_PQ):
                        continue

                    C.add_arc(PQ, (a, c), next_PQ, w1 * w2)

                    if next_PQ not in visited:
                        stack.append(next_PQ)
                        visited.add(next_PQ)

        return C

    def _augment_epsilon_transitions(self, idx):
        """Augments the FST by changing the appropriate epsilon transitions to
        epsilon_1 or epsilon_2 transitions to be able to perform the composition
        correctly.  See Fig. 7 on p. 17 of Mohri, "Weighted Automata Algorithms".

        Args:
            idx: 0 if this is the first FST in composition, 1 if second

        Returns:
            FST with augmented epsilon transitions
        """
        assert idx in [0, 1]

        T = self.spawn(keep_init=True, keep_stop=True)

        for i in self.states:
            if idx == 0:
                T.add_arc(i, (ε, ε_1), i, self.R.one)
            else:
                T.add_arc(i, (ε_2, ε), i, self.R.one)
            for ab, j, w in self.arcs(i):
                if idx == 0 and ab[1] == ε:
                    ab = (ab[0], ε_2)
                elif idx == 1 and ab[0] == ε:
                    ab = (ε_1, ab[1])
                T.add_arc(i, ab, j, w)

        return T

    @classmethod
    def diag(cls, fsa):
        """Convert FSA to diagonal FST that maps strings to themselves.

        Args:
            fsa: Input FSA to convert

        Returns:
            FST that maps each string accepted by fsa to itself with same weight
        """
        fst = cls(fsa.R)
        for i, a, j, w in fsa.arcs():
            fst.add_arc(i, (a, a), j, w)
        for i, w in fsa.I:
            fst.add_I(i, w)
        for i, w in fsa.F:
            fst.add_F(i, w)
        return fst

    def coarsen(self, N, A, B):
        """Create coarsened Boolean FST by mapping states and symbols.

        Args:
            N: Function mapping states to coarsened states
            A: Function mapping input symbols to coarsened input symbols
            B: Function mapping output symbols to coarsened output symbols

        Returns:
            Coarsened Boolean FST
        """
        m = FST(Boolean)
        for i in self.start:
            m.add_I(N(i), Boolean.one)
        for i in self.stop:
            m.add_F(N(i), Boolean.one)
        for i, (a, b), j, _ in self.arcs():
            m.add_arc(N(i), (A(a), B(b)), N(j), Boolean.one)
        return m

T cached property

Return the transpose of this FST by swapping input/output labels.

Returns:

Type Description

A new FST with input/output labels swapped

__call__(x, y)

Compute the weight of mapping input x to output y.

If x or y is None, returns a weighted language representing the cross section.

Parameters:

Name Type Description Default
x

Input string or None

required
y

Output string or None

required

Returns:

Type Description

Weight of mapping x to y, or a WFSA representing the cross section if x or y is None

Source code in genlm/grammar/fst.py
def __call__(self, x, y):
    """Compute the weight of mapping input x to output y.

    If x or y is None, returns a weighted language representing the cross section.

    Args:
        x: Input string or None
        y: Output string or None

    Returns:
        Weight of mapping x to y, or a WFSA representing the cross section
        if x or y is None.
    """
    # Neither side constrained: the transducer itself is the answer.
    if x is None and y is None:
        return self

    # Only the input is fixed: compose on the left and project the outputs.
    if y is None:
        left = FST.from_string(x, self.R)
        return (left @ self).project(1)

    # Only the output is fixed: compose on the right and project the inputs.
    if x is None:
        right = FST.from_string(y, self.R)
        return (self @ right).project(0)

    # Both sides fixed: total weight of the pinched composition.
    left = FST.from_string(x, self.R)
    right = FST.from_string(y, self.R)
    return (left @ self @ right).total_weight()

__init__(R)

Initialize an empty FST.

Parameters:

Name Type Description Default
R

The semiring for transition weights

required
Source code in genlm/grammar/fst.py
def __init__(self, R):
    """Initialize an empty FST.

    Args:
        R: The semiring for transition weights
    """
    super().__init__(R=R)

    # Input (A) and output (B) alphabets; grown as arcs are added.
    self.A = set()
    self.B = set()

__matmul__(other)

Compose this FST with another FST or automaton.

Parameters:

Name Type Description Default
other

Another FST, CFG or automaton to compose with

required

Returns:

Type Description

The composed FST

Source code in genlm/grammar/fst.py
def __matmul__(self, other):
    """Compose this FST with another FST or automaton.

    Non-FST arguments are adapted first: a CFG is composed the other way
    around against our transpose; a plain automaton is lifted to an
    identity transducer via `FST.diag`.

    Args:
        other: Another FST, CFG or automaton to compose with

    Returns:
        The composed FST
    """
    if not isinstance(other, FST):
        from genlm.grammar.cfg import CFG

        if isinstance(other, CFG):
            # CFG composition is implemented on the grammar side; use the
            # transpose so the label sides line up.
            return other @ self.T
        else:
            # Lift the automaton to an FST that maps strings to themselves.
            other = FST.diag(other)

    # Epsilon-aware composition: our output-side epsilons are renamed (ε_2),
    # the other operand's input-side epsilons are renamed (ε_1), and an
    # epsilon-filter FST is sandwiched between them to combine them soundly.
    # Minor efficiency trick: it's slightly more efficient to associate the
    # composition so the smaller operand meets the filter first.
    if len(self.states) < len(other.states):
        return (
            self._augment_epsilon_transitions(0)  # rename epsilons on the output side
            ._compose(
                epsilon_filter_fst(self.R, self.B), coarsen=False
            )  # this FST carefully combines the special epsilons
            ._compose(
                other._augment_epsilon_transitions(1)
            )  # rename epsilons on the input side
        )

    else:
        return self._augment_epsilon_transitions(
            0
        )._compose(  # rename epsilons on the output side
            epsilon_filter_fst(
                self.R, self.B
            )._compose(  # this FST carefully combines the special epsilons
                other._augment_epsilon_transitions(1), coarsen=False
            )
        )  # rename epsilons on the input side

add_arc(i, ab, j, w)

Add a weighted transition between states.

Parameters:

Name Type Description Default
i

Source state

required
ab

Tuple (a,b) of input/output symbols, or EPSILON

required
j

Target state

required
w

Weight of the transition

required

Returns:

Type Description

self

Source code in genlm/grammar/fst.py
def add_arc(self, i, ab, j, w):  # pylint: disable=arguments-renamed
    """Add a weighted transition between states.

    Also records the label's symbols in the input/output alphabets.

    Args:
        i: Source state
        ab: Tuple (a,b) of input/output symbols, or EPSILON
        j: Target state
        w: Weight of the transition

    Returns:
        self
    """
    # A bare EPSILON label contributes no symbols to either alphabet.
    if ab != EPSILON:
        a, b = ab
        self.A.add(a)
        self.B.add(b)
    return super().add_arc(i, ab, j, w)

coarsen(N, A, B)

Create coarsened Boolean FST by mapping states and symbols.

Parameters:

Name Type Description Default
N

Function mapping states to coarsened states

required
A

Function mapping input symbols to coarsened input symbols

required
B

Function mapping output symbols to coarsened output symbols

required

Returns:

Type Description

Coarsened Boolean FST

Source code in genlm/grammar/fst.py
def coarsen(self, N, A, B):
    """Create a coarsened Boolean FST by mapping states and symbols.

    Args:
        N: Function mapping states to coarsened states
        A: Function mapping input symbols to coarsened input symbols
        B: Function mapping output symbols to coarsened output symbols

    Returns:
        Coarsened Boolean FST (all weights collapse to Boolean.one).
    """
    result = FST(Boolean)
    # Initial and final states map through N; original weights are dropped.
    for q in self.start:
        result.add_I(N(q), Boolean.one)
    for q in self.stop:
        result.add_F(N(q), Boolean.one)
    # Each arc's endpoints and both label components are mapped.
    for src, (a, b), dst, _ in self.arcs():
        result.add_arc(N(src), (A(a), B(b)), N(dst), Boolean.one)
    return result

diag(fsa) classmethod

Convert FSA to diagonal FST that maps strings to themselves.

Parameters:

Name Type Description Default
fsa

Input FSA to convert

required

Returns:

Type Description

FST that maps each string accepted by fsa to itself with same weight

Source code in genlm/grammar/fst.py
@classmethod
def diag(cls, fsa):
    """Lift an FSA into an identity transducer over the same language.

    Args:
        fsa: Input FSA to convert

    Returns:
        FST mapping each string accepted by `fsa` to itself, with the
        same weight the FSA assigns it.
    """
    machine = cls(fsa.R)
    # Each arc label becomes the (label, label) identity pair.
    for src, label, dst, weight in fsa.arcs():
        machine.add_arc(src, (label, label), dst, weight)
    # Initial and final weights carry over unchanged.
    for q, weight in fsa.I:
        machine.add_I(q, weight)
    for q, weight in fsa.F:
        machine.add_F(q, weight)
    return machine

from_pairs(pairs, R) staticmethod

Create an FST accepting the given input-output string pairs.

Parameters:

Name Type Description Default
pairs

List of (input_string, output_string) tuples

required
R

Semiring for weights

required

Returns:

Type Description

An FST accepting the given string pairs with weight one

Source code in genlm/grammar/fst.py
@staticmethod
def from_pairs(pairs, R):
    """Create an FST accepting the given input-output string pairs.

    Args:
        pairs: List of (input_string, output_string) tuples
        R: Semiring for weights

    Returns:
        An FST accepting the given string pairs with weight one
    """
    machine = FST(R)
    machine.add_I(0, R.one)
    machine.add_F(1, R.one)
    for k, (xs, ys) in enumerate(pairs):
        # Dedicated branch for pair k, entered from state 0 via epsilon.
        machine.add_arc(0, EPSILON, (k, 0), R.one)
        # The shorter string is padded with EPSILON so both sides advance together.
        steps = list(zip_longest(xs, ys, fillvalue=EPSILON))
        for t, lab in enumerate(steps):
            machine.add_arc((k, t), lab, (k, t + 1), R.one)
        # Rejoin the shared final state 1 after max(len(xs), len(ys)) steps.
        machine.add_arc((k, len(steps)), EPSILON, 1, R.one)
    return machine

from_string(xs, R, w=None) classmethod

Create an FST that accepts only the given string with optional weight.

Parameters:

Name Type Description Default
xs

Input string

required
R

Semiring for weights

required
w

Optional weight for the string

None

Returns:

Type Description

An FST accepting only xs with weight w

Source code in genlm/grammar/fst.py
@classmethod
def from_string(cls, xs, R, w=None):
    """Create an FST that accepts only the given string with optional weight.

    Args:
        xs: Input string
        R: Semiring for weights
        w: Optional weight for the string

    Returns:
        An FST accepting only xs with weight w
    """
    # Build the single-string acceptor, then lift it to an identity transducer.
    acceptor = WFSA.from_string(xs=xs, R=R, w=w)
    return cls.diag(acceptor)

project(axis)

Project the FST onto one of its components to create a WFSA.

Parameters:

Name Type Description Default
axis

0 for input projection, 1 for output projection

required

Returns:

Type Description

A WFSA over the projected alphabet

Source code in genlm/grammar/fst.py
def project(self, axis):
    """Project the FST onto one of its label components to create a WFSA.

    Args:
        axis: 0 for input projection, 1 for output projection

    Returns:
        A WFSA over the projected alphabet
    """
    assert axis in [0, 1]
    result = WFSA(R=self.R)
    # Keep only the requested side of each (input, output) label.
    for src, (a, b), dst, w in self.arcs():
        result.add_arc(src, a if axis == 0 else b, dst, w)
    for q, w in self.I:
        result.add_I(q, w)
    for q, w in self.F:
        result.add_F(q, w)
    return result

prune_to_alphabet(A, B)

Remove transitions with labels not in the given alphabets.

Parameters:

Name Type Description Default
A

Set of allowed input symbols, or None to allow all

required
B

Set of allowed output symbols, or None to allow all

required

Returns:

Type Description

A new FST with invalid transitions removed

Source code in genlm/grammar/fst.py
def prune_to_alphabet(self, A, B):
    """Remove transitions whose labels are not in the given alphabets.

    Args:
        A: Set of allowed input symbols, or None to allow all
        B: Set of allowed output symbols, or None to allow all

    Returns:
        A new (trimmed) FST with invalid transitions removed
    """
    pruned = self.spawn()
    for src, (a, b), dst, w in self.arcs():
        # None acts as a wildcard: that side is never filtered.
        a_ok = A is None or a in A
        b_ok = B is None or b in B
        if a_ok and b_ok:
            pruned.add_arc(src, (a, b), dst, w)
    for q, w in self.I:
        pruned.add_I(q, w)
    for q, w in self.F:
        pruned.add_F(q, w)
    return pruned.trim

set_arc(i, ab, j, w)

Set the weight of a transition between states.

Parameters:

Name Type Description Default
i

Source state

required
ab

Tuple (a,b) of input/output symbols, or EPSILON

required
j

Target state

required
w

New weight for the transition

required

Returns:

Type Description

self

Source code in genlm/grammar/fst.py
def set_arc(self, i, ab, j, w):  # pylint: disable=arguments-renamed
    """Set the weight of a transition between states.

    Also records the label's symbols in the input/output alphabets.

    Args:
        i: Source state
        ab: Tuple (a,b) of input/output symbols, or EPSILON
        j: Target state
        w: New weight for the transition

    Returns:
        self
    """
    # A bare EPSILON label contributes no symbols to either alphabet.
    if ab != EPSILON:
        a, b = ab
        self.A.add(a)
        self.B.add(b)
    return super().set_arc(i, ab, j, w)

WFSA

Bases: WFSA

Weighted finite-state automata where weights are a field (e.g., real-valued).

Source code in genlm/grammar/wfsa/field_wfsa.py
class WFSA(base.WFSA):
    """
    Weighted finite-state automata where weights are a field (e.g., real-valued).
    """

    def __init__(self, R=Float):
        """Initialize an empty WFSA over the field R (defaults to Float)."""
        super().__init__(R=R)

    def __hash__(self):
        # Hash via the canonical form so machines that compare equal hash equally.
        return hash(self.simple)

    def threshold(self, threshold):
        "Drop init, arcs, final below a given abs-threshold."
        m = self.__class__(self.R)
        for q, w in self.I:
            if abs(w) >= threshold:
                m.add_I(q, w)
        for i, a, j, w in self.arcs():
            if abs(w) >= threshold:
                m.add_arc(i, a, j, w)
        for q, w in self.F:
            if abs(w) >= threshold:
                m.add_F(q, w)
        return m

    def graphviz(
        self,
        fmt=lambda x: f"{round(x, 3):g}" if isinstance(x, (float, int)) else str(x),
        **kwargs,
    ):  # pylint: disable=arguments-differ
        """Render with numeric weights rounded to three places for readability."""
        return super().graphviz(fmt=fmt, **kwargs)

    @cached_property
    def simple(self):
        """Dense-matrix canonical form: start vector, per-symbol transition
        matrices, and stop vector, built from the epsilon-removed, renumbered
        machine."""
        # Rebind self to a normalized copy; states are then dense 0..dim-1 indices.
        self = self.epsremove.renumber

        S = self.dim
        start = np.full(S, self.R.zero)
        arcs = {a: np.full((S, S), self.R.zero) for a in self.alphabet}
        stop = np.full(S, self.R.zero)

        # Accumulate (+=) rather than assign, in case of parallel arcs/weights.
        for i, w in self.I:
            start[i] += w
        for i, a, j, w in self.arcs():
            arcs[a][i, j] += w
        for i, w in self.F:
            stop[i] += w

        # Sanity check: epsremove above should have eliminated all epsilons.
        assert EPSILON not in arcs

        return Simple(start, arcs, stop)

    def __eq__(self, other):
        # Equality is delegated to the canonical forms.
        return self.simple == other.simple

    def counterexample(self, other):
        """Delegate to the canonical forms' counterexample search."""
        return self.simple.counterexample(other.simple)

    @cached_property
    def min(self):
        """Minimized equivalent machine, computed via the canonical form."""
        return self.simple.min.to_wfsa()

    #    @cached_property
    #    def epsremove(self):
    #        return self.simple.to_wfsa()

    def multiplicity(self, m):
        """Prefix a weight-m epsilon machine (combined via *), scaling the
        machine by m."""
        return WFSA.lift(EPSILON, m) * self

    @classmethod
    def lift(cls, x, w, R=None):
        """Build a two-state machine accepting the single symbol x with weight w."""
        if R is None:
            R = Float
        m = cls(R=R)
        m.add_I(0, R.one)
        m.add_arc(0, x, 1, w)
        m.add_F(1, R.one)
        return m

threshold(threshold)

Drop init, arcs, final below a given abs-threshold.

Source code in genlm/grammar/wfsa/field_wfsa.py
def threshold(self, threshold):
    """Drop initial weights, arcs, and final weights whose magnitude is
    below the given absolute threshold."""
    kept = self.__class__(self.R)
    for state, weight in self.I:
        if abs(weight) >= threshold:
            kept.add_I(state, weight)
    for src, label, dst, weight in self.arcs():
        if abs(weight) >= threshold:
            kept.add_arc(src, label, dst, weight)
    for state, weight in self.F:
        if abs(weight) >= threshold:
            kept.add_F(state, weight)
    return kept

add_EOS(cfg, eos=None)

Add an end-of-sequence symbol to a CFG's language.

Transforms the grammar to append the EOS symbol to every string it generates.

Parameters:

Name Type Description Default
cfg CFG

The input grammar

required
eos optional

The end-of-sequence symbol to add. Defaults to ▪.

None

Returns:

Type Description
CFG

A new grammar that generates strings ending in EOS

Raises:

Type Description
AssertionError

If EOS is already in the grammar's vocabulary

Source code in genlm/grammar/cfglm.py
def add_EOS(cfg, eos=None):
    """Add an end-of-sequence symbol to a CFG's language.

    Transforms the grammar to append the EOS symbol to every string it generates.

    Args:
        cfg (CFG): The input grammar
        eos (optional): The end-of-sequence symbol to add. Defaults to ▪.

    Returns:
        (CFG): A new grammar that generates strings ending in EOS

    Raises:
        AssertionError: If EOS is already in the grammar's vocabulary

    """
    # Explicit None test: the previous `eos or EOS` idiom silently replaced
    # falsy-but-valid sentinels (e.g. 0 or "") with the default symbol.
    if eos is None:
        eos = EOS
    assert eos not in cfg.V
    S = _gen_nt("<START>")
    new = cfg.spawn(S=S)
    new.V.add(eos)
    # New start rule S -> old_start eos appends the sentinel to every derivation.
    new.add(cfg.R.one, S, cfg.S, eos)
    # All original rules are copied unchanged.
    for r in cfg:
        new.add(r.w, r.head, *r.body)
    return new

locally_normalize(self, **kwargs)

Locally normalize the grammar's rule weights.

Returns a transformed grammar where: 1. The total weight of rules with the same head symbol sums to one 2. Each derivation's weight is proportional to the original grammar (differs only by a multiplicative normalization constant)

Parameters:

Name Type Description Default
**kwargs

Additional arguments passed to self.agenda()

{}

Returns:

Type Description
CFG

A new grammar with locally normalized weights

Source code in genlm/grammar/cfglm.py
def locally_normalize(self, **kwargs):
    """Locally normalize the grammar's rule weights.

    Returns a transformed grammar where:
    1. The total weight of rules with the same head symbol sums to one
    2. Each derivation's weight is proportional to the original grammar
       (differs only by a multiplicative normalization constant)

    Args:
        **kwargs: Additional arguments passed to self.agenda()

    Returns:
        (CFG): A new grammar with locally normalized weights
    """
    normalized = self.spawn()
    Z = self.agenda(**kwargs)
    for rule in self:
        denom = Z[rule.head]
        # Heads with zero total weight contribute nothing; skipping them
        # also avoids dividing by zero below.
        if denom == 0:
            continue
        normalized.add(rule.w * Z.product(rule.body) / denom, rule.head, *rule.body)
    return normalized