

CFG

Weighted Context-free Grammar

A weighted context-free grammar consists of:

  • R: A semiring that defines the weights

  • S: A start symbol (nonterminal)

  • V: A set of terminal symbols (vocabulary)

  • N: A set of nonterminal symbols

  • rules: A list of weighted production rules
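
The semiring R is what makes the grammar "weighted": it supplies a zero, a one, an addition for combining alternative derivations, and a multiplication for chaining the steps within one derivation. As a rough sketch (plain Python, not the library's actual semiring classes):

```python
# Hypothetical sketch of a semiring of real weights; the library's own
# semiring classes may expose a different interface.
class RealSemiring:
    zero = 0.0  # identity for summing over alternative derivations
    one = 1.0   # identity for multiplying along one derivation

    @staticmethod
    def add(a, b):
        # total weight over alternative derivations
        return a + b

    @staticmethod
    def mul(a, b):
        # weight accumulated along a single derivation
        return a * b
```

Swapping in a Boolean semiring (or/and) would turn weighted parsing into plain recognition without changing any of the algorithms below.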

Each rule has the form: w: X -> Y1 Y2 ... Yn where:

  • w is a weight from the semiring R

  • X is a nonterminal symbol

  • Y1...Yn are terminal or nonterminal symbols
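
To make the rule format concrete, here is a small self-contained sketch (not the library's implementation) that stores rules as (weight, head, body) triples and scores a string with CKY; it assumes the toy grammar is already in Chomsky normal form and uses ordinary real weights:

```python
from collections import defaultdict

# Rules in the form w: X -> Y1 ... Yn, stored as (weight, head, body).
# This toy grammar is in CNF: each body is two nonterminals or one terminal.
rules = [
    (1.0, "S", ("A", "B")),  # 1.0: S -> A B
    (0.5, "A", ("A", "B")),  # 0.5: A -> A B
    (0.5, "A", ("a",)),      # 0.5: A -> a
    (1.0, "B", ("b",)),      # 1.0: B -> b
]

def cky_weight(rules, xs, start="S"):
    """Total weight of all derivations of xs (real-weight semiring)."""
    n = len(xs)
    chart = defaultdict(float)  # (i, X, k) -> weight of X spanning xs[i:k]
    # preterminal rules
    for i, x in enumerate(xs):
        for w, head, body in rules:
            if body == (x,):
                chart[i, head, i + 1] += w
    # binary rules over widening spans
    for span in range(2, n + 1):
        for i in range(n - span + 1):
            k = i + span
            for j in range(i + 1, k):
                for w, head, body in rules:
                    if len(body) == 2:
                        Y, Z = body
                        chart[i, head, k] += w * chart[i, Y, j] * chart[j, Z, k]
    return chart[0, start, n]

print(cky_weight(rules, ["a", "b"]))       # 1.0 * 0.5 * 1.0 = 0.5
print(cky_weight(rules, ["a", "b", "b"]))  # uses A -> A B once: 0.25
```

The library's `_parse_chart` method below follows the same pattern, but over an arbitrary semiring and with a separate nullary case for the empty string.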

Source code in genlm/grammar/cfg.py
class CFG:
    """
    Weighted Context-free Grammar

    A weighted context-free grammar consists of:\n
    - `R`: A semiring that defines the weights\n
    - `S`: A start symbol (nonterminal)\n
    - `V`: A set of terminal symbols (vocabulary)\n
    - `N`: A set of nonterminal symbols\n
    - `rules`: A list of weighted production rules\n

    Each rule has the form: w: X -> Y1 Y2 ... Yn where:\n
    - w is a weight from the semiring R\n
    - X is a nonterminal symbol\n
    - Y1...Yn are terminal or nonterminal symbols\n
    """

    def __init__(self, R, S, V):
        """
        Initialize a weighted CFG.

        Args:
            R: The semiring for rule weights
            S: The start symbol (nonterminal)
            V: The set of terminal symbols (vocabulary)
        """
        self.R = R  # semiring
        self.V = V  # alphabet
        self.N = {S}  # nonterminals
        self.S = S  # unique start symbol
        self.rules = []  # rules
        self._trim_cache = [None, None]

    def __repr__(self):
        """Return string representation of the grammar."""
        return "Grammar {\n%s\n}" % "\n".join(f"  {r}" for r in self)

    def _repr_html_(self):
        """Return HTML representation of the grammar for Jupyter notebooks."""
        return f'<pre style="width: fit-content; text-align: left; border: thin solid black; padding: 0.5em;">{self}</pre>'

    @classmethod
    def from_string(
        cls,
        string,
        semiring,
        comment="#",
        start="S",
        is_terminal=lambda x: not x[0].isupper(),
    ):
        """
        Create a CFG from a string representation.

        Args:
            string: The grammar rules as a string
            semiring: The semiring for rule weights
            comment: Comment character to ignore lines (default: '#')
            start: Start symbol (default: 'S')
            is_terminal: Function to identify terminal symbols (default: first character is not uppercase)

        Returns:
            A new CFG instance
        """
        V = set()
        cfg = cls(R=semiring, S=start, V=V)
        string = string.replace("->", "→")  # synonym for the arrow
        for line in string.split("\n"):
            line = line.strip()
            if not line or line.startswith(comment):
                continue
            try:
                [(w, lhs, rhs)] = re.findall(r"(.*):\s*(\S+)\s*→\s*(.*)$", line)
                lhs = lhs.strip()
                rhs = rhs.strip().split()
                for x in rhs:
                    if is_terminal(x):
                        V.add(x)
                cfg.add(semiring.from_string(w), lhs, *rhs)
            except ValueError:
                raise ValueError(f"bad input line:\n{line}")  # pylint: disable=W0707
        return cfg

    def __getitem__(self, root):
        """
        Return a grammar that denotes the sublanguage of the nonterminal `root`.

        Args:
            root: The nonterminal to use as the new start symbol

        Returns:
            A new CFG with root as the start symbol
        """
        new = self.spawn(S=root)
        for r in self:
            new.add(r.w, r.head, *r.body)
        return new

    def __len__(self):
        """Return number of rules in the grammar."""
        return len(self.rules)

    def __call__(self, xs):
        """
        Compute the total weight of the sequence xs.

        Args:
            xs: A sequence of terminal symbols

        Returns:
            The total weight of all derivations of xs
        """
        self = self.cnf  # need to do this here because the start symbol might change
        return self._parse_chart(xs)[0, self.S, len(xs)]

    def _parse_chart(self, xs):
        """
        Implement the CKY algorithm to compute the total weight of the sequence xs.

        Args:
            xs: A sequence of terminal symbols

        Returns:
            A chart containing the weights of all subderivations
        """
        (nullary, terminal, binary) = self._cnf  # will convert to CNF
        N = len(xs)
        # nullary rule
        c = self.R.chart()
        for i in range(N + 1):
            c[i, self.S, i] += nullary
        # preterminal rules
        for i in range(N):
            for r in terminal[xs[i]]:
                c[i, r.head, i + 1] += r.w
        # binary rules
        for span in range(2, N + 1):
            for i in range(N - span + 1):
                k = i + span
                for j in range(i + 1, k):
                    for r in binary:
                        X, [Y, Z] = r.head, r.body
                        c[i, X, k] += r.w * c[i, Y, j] * c[j, Z, k]
        return c

    def language(self, depth):
        """
        Enumerate the strings generated by this CFG via derivations up to the given depth.

        Args:
            depth: Maximum derivation depth to consider

        Returns:
            A chart containing the weighted language up to the given depth
        """
        lang = self.R.chart()
        for d in self.derivations(self.S, depth):
            lang[d.Yield()] += d.weight()
        return lang

    @cached_property
    def rhs(self):
        """
        Map from each nonterminal to the list of rules with it as their left-hand side.

        Returns:
            A dict mapping nonterminals to lists of rules
        """
        rhs = defaultdict(list)
        for r in self:
            rhs[r.head].append(r)
        return rhs

    def is_terminal(self, x):
        """Return True if x is a terminal symbol."""
        return x in self.V

    def is_nonterminal(self, X):
        """Return True if X is a nonterminal symbol."""
        return not self.is_terminal(X)

    def __iter__(self):
        """Iterate over the rules in the grammar."""
        return iter(self.rules)

    @property
    def size(self):
        """Return total size of the grammar (sum of rule lengths)."""
        return sum(1 + len(r.body) for r in self)

    @property
    def num_rules(self):
        """Return number of rules in the grammar."""
        return len(self.rules)

    @property
    def expected_length(self):
        """
        Compute the expected length of a string using the Expectation semiring.

        Returns:
            The expected length of strings generated by this grammar

        Raises:
            AssertionError: If grammar is not over the Float semiring
        """
        assert self.R == Float, (
            "This method only supports grammars over the Float semiring"
        )
        new_cfg = self.__class__(R=Expectation, S=self.S, V=self.V)
        for r in self:
            new_cfg.add(
                Expectation(r.w, r.w * sum(self.is_terminal(y) for y in r.body)),
                r.head,
                *r.body,
            )
        return new_cfg.treesum().score[1]

    def spawn(self, *, R=None, S=None, V=None):
        """
        Create an empty grammar with the same R, S, and V.

        Args:
            R: Optional new semiring
            S: Optional new start symbol
            V: Optional new vocabulary

        Returns:
            A new empty CFG with specified parameters
        """
        return self.__class__(
            R=self.R if R is None else R,
            S=self.S if S is None else S,
            V=set(self.V) if V is None else V,
        )

    def add(self, w, head, *body):
        """
        Add a rule of the form w: head -> body_1 body_2 ... body_k.

        Args:
            w: The rule weight
            head: The left-hand side nonterminal
            *body: The right-hand side symbols

        Returns:
            The added rule, or None if weight is zero
        """
        if w == self.R.zero:
            return  # skip rules with weight zero
        self.N.add(head)
        r = Rule(w, head, body)
        self.rules.append(r)
        return r

    def renumber(self):
        """
        Rename nonterminals to integers.

        Returns:
            A new CFG with integer nonterminals
        """
        i = Integerizer()
        max_v = max((x for x in self.V if isinstance(x, int)), default=0)
        return self.rename(lambda x: i(x) + max_v + 1)

    def rename(self, f):
        """
        Return a new grammar that is the result of applying f to each nonterminal.

        Args:
            f: Function to rename nonterminals

        Returns:
            A new CFG with renamed nonterminals
        """
        new = self.spawn(S=f(self.S))
        for r in self:
            new.add(
                r.w, f(r.head), *((y if self.is_terminal(y) else f(y) for y in r.body))
            )
        return new

    def map_values(self, f, R):
        """
        Return a new grammar that is the result of applying f: self.R -> R to each rule's weight.

        Args:
            f: Function to map weights
            R: New semiring for weights

        Returns:
            A new CFG with mapped weights
        """
        new = self.spawn(R=R)
        for r in self:
            new.add(f(r.w), r.head, *r.body)
        return new

    def assert_equal(self, other, verbose=False, throw=True):
        """
        Assert that self and other are equal modulo rule reordering.

        Args:
            other: The grammar to compare against
            verbose: If True, print differences
            throw: If True, raise AssertionError on inequality

        Raises:
            AssertionError: If grammars are not equal and throw=True
        """
        assert verbose or throw
        if isinstance(other, str):
            other = self.__class__.from_string(other, self.R)
        if verbose:
            # TODO: need to check the weights in the print out; we do it in the assertion
            S = set(self.rules)
            G = set(other.rules)
            for r in sorted(S | G, key=str):
                if r in S and r in G:
                    continue
                # if r in S and r not in G: continue
                # if r not in S and r in G: continue
                print(
                    colors.mark(r in S),
                    # colors.mark(r in S and r in G),
                    colors.mark(r in G),
                    r,
                )
        assert not throw or Counter(self.rules) == Counter(other.rules), (
            f"\n\nhave=\n{str(self)}\nwant=\n{str(other)}"
        )

    def treesum(self, **kwargs):
        """
        Total weight of the start symbol.

        Returns:
            The total weight of all derivations from the start symbol
        """
        return self.agenda(**kwargs)[self.S]

    def trim(self, bottomup_only=False):
        """
        Return an equivalent grammar with no dead or useless nonterminals or rules.

        Args:
            bottomup_only: If True, only remove non-generating nonterminals

        Returns:
            A new trimmed CFG
        """
        if self._trim_cache[bottomup_only] is not None:
            return self._trim_cache[bottomup_only]

        C = set(self.V)
        C.update(e.head for e in self.rules if len(e.body) == 0)

        incoming = defaultdict(list)
        outgoing = defaultdict(list)
        for e in self:
            incoming[e.head].append(e)
            for b in e.body:
                outgoing[b].append(e)

        agenda = set(C)
        while agenda:
            x = agenda.pop()
            for e in outgoing[x]:
                if all((b in C) for b in e.body):
                    if e.head not in C:
                        C.add(e.head)
                        agenda.add(e.head)

        if bottomup_only:
            val = self._trim(C)
            self._trim_cache[bottomup_only] = val
            val._trim_cache[bottomup_only] = val
            return val

        T = {self.S}
        agenda.update(T)
        while agenda:
            x = agenda.pop()
            for e in incoming[x]:
                # assert e.head in T
                for b in e.body:
                    if b not in T and b in C:
                        T.add(b)
                        agenda.add(b)

        val = self._trim(T)
        self._trim_cache[bottomup_only] = val
        val._trim_cache[bottomup_only] = val
        return val

    def cotrim(self):
        """
        Trim the grammar so that all nonterminals are generating.

        Returns:
            A new CFG with only generating nonterminals
        """
        return self.trim(bottomup_only=True)

    def _trim(self, symbols):
        """
        Helper method for trim() - creates new grammar with only given symbols.

        Args:
            symbols: Set of symbols to keep

        Returns:
            A new CFG with only rules using the given symbols
        """
        new = self.spawn()
        for p in self:
            if p.head in symbols and p.w != self.R.zero and set(p.body) <= symbols:
                new.add(p.w, p.head, *p.body)
        return new

    # ___________________________________________________________________________
    # Derivation enumeration

    def derivations(self, X, H):
        """
        Enumerate derivations of symbol X with height <= H.

        Args:
            X: The symbol to derive from; pass None to use the start symbol
            H: Maximum derivation height

        Yields:
            Derivation objects representing derivation trees
        """
        if X is None:
            X = self.S
        if self.is_terminal(X):
            yield X
        elif H <= 0:
            return
        else:
            for r in self.rhs[X]:
                for ys in self._derivations_list(r.body, H - 1):
                    yield Derivation(r, X, *ys)

    def _derivations_list(self, Xs, H):
        """
        Helper method for derivations; expands the list of symbols Xs up to height H.

        Args:
            Xs: List of symbols to derive
            H: Maximum derivation height

        Yields:
            Tuples of derivations
        """
        if len(Xs) == 0:
            yield ()
        else:
            for x in self.derivations(Xs[0], H):
                for xs in self._derivations_list(Xs[1:], H):
                    yield (x, *xs)

    # ___________________________________________________________________________
    # Transformations

    def _unary_graph(self):
        """
        Build the weighted graph of unary (nonterminal-to-nonterminal) rules.

        Returns:
            A WeightedGraph whose closure gives the unary-chain weights
        """
        A = WeightedGraph(self.R)
        for r in self:
            if len(r.body) == 1 and self.is_nonterminal(r.body[0]):
                A[r.head, r.body[0]] += r.w
        A.N |= self.N
        return A

    def _unary_graph_transpose(self):
        """
        Build the transpose of the weighted graph of unary rules.

        Returns:
            A WeightedGraph representing transposed unary rule closure
        """
        A = WeightedGraph(self.R)
        for r in self:
            if len(r.body) == 1 and self.is_nonterminal(r.body[0]):
                A[r.body[0], r.head] += r.w
        A.N |= self.N
        return A

    def unaryremove(self):
        """
        Return an equivalent grammar with no unary rules.

        Returns:
            A new CFG without unary rules
        """
        W = self._unary_graph().closure_scc_based()
        # W = self._unary_graph().closure_reference()

        new = self.spawn()
        for r in self:
            if len(r.body) == 1 and self.is_nonterminal(r.body[0]):
                continue
            for Y in self.N:
                new.add(W[Y, r.head] * r.w, Y, *r.body)

        return new

    def has_unary_cycle(self):
        """
        Check if the grammar has unary cycles.

        Returns:
            True if the grammar contains unary cycles
        """
        f = self._unary_graph().buckets
        return any(
            True for r in self if len(r.body) == 1 and f.get(r.head) == f.get(r.body[0])
        )

    def unarycycleremove(self, trim=True):
        """
        Return an equivalent grammar with no unary cycles.

        Args:
            trim: If True, trim the resulting grammar

        Returns:
            A new CFG without unary cycles
        """

        def bot(x):
            return x if x in acyclic else (x, "bot")

        G = self._unary_graph()

        new = self.spawn(S=self.S)

        bucket = G.buckets

        acyclic = set()
        for nodes, _ in G.Blocks:
            if len(nodes) == 1:
                [X] = nodes
                if G[X, X] == self.R.zero:
                    acyclic.add(X)

        # run Lehmann's algorithm on each cyclic SCC
        for nodes, W in G.Blocks:
            if len(nodes) == 1:
                [X] = nodes
                if X in acyclic:
                    continue

            for X1, X2 in W:
                new.add(W[X1, X2], X1, bot(X2))

        for r in self:
            if len(r.body) == 1 and bucket.get(r.body[0]) == bucket[r.head]:
                continue
            new.add(r.w, bot(r.head), *r.body)

        # TODO: figure out how to ensure that the new grammar is trimmed by
        # construction (assuming the input grammar was trim).
        if trim:
            new = new.trim()

        return new

    def nullaryremove(self, binarize=True, trim=True, **kwargs):
        """
        Return an equivalent grammar with no nullary rules except for one at the start symbol.

        Args:
            binarize: If True, binarize the grammar first
            trim: If True, trim the resulting grammar
            **kwargs: Additional arguments passed to _push_null_weights

        Returns:
            A new CFG without nullary rules (except at start)
        """
        # A really wide rule can take a very long time because of the power set
        # in this rule so it is really important to binarize.
        if binarize:
            self = self.binarize()  # pragma: no cover
        self = self.separate_start()
        tmp = self._push_null_weights(self.null_weight(), **kwargs)
        return tmp.trim() if trim else tmp

    def null_weight(self):
        """
        Compute the map from nonterminal to total weight of generating the empty string.

        Returns:
            A dict mapping nonterminals to their null weights
        """
        ecfg = self.spawn(V=set())
        for p in self:
            if not any(self.is_terminal(y) for y in p.body):
                ecfg.add(p.w, p.head, *p.body)
        return ecfg.agenda()

    def null_weight_start(self):
        """
        Compute the null weight of the start symbol.

        Returns:
            The total weight of generating the empty string from the start symbol
        """
        return self.null_weight()[self.S]

    def _push_null_weights(self, null_weight, rename=NotNull):
        """
        Returns a grammar that generates the same weighted language but is nullary-free
        at all nonterminals except its start symbol.

        Args:
            null_weight: Dict mapping nonterminals to their null weights
            rename: Function to rename nonterminals (default: NotNull)

        Returns:
            A new CFG without nullary rules (except at start)
        """
        # Warning: this method can misbehave when `separate_start` has not
        # been run first, so we assert that the start symbol never appears
        # on the right-hand side of any rule.
        assert self.S not in {y for r in self for y in r.body}

        def f(x):
            "Rename nonterminal if necessary"
            if (
                null_weight[x] == self.R.zero or x == self.S
            ):  # not necessary; keep old name
                return x
            else:
                return rename(x)

        rcfg = self.spawn()
        rcfg.add(null_weight[self.S], self.S)

        for r in self:
            if len(r.body) == 0:
                continue  # drop nullary rule

            for B in product([0, 1], repeat=len(r.body)):
                v, new_body = r.w, []

                for i, b in enumerate(B):
                    if b:
                        v *= null_weight[r.body[i]]
                    else:
                        new_body.append(f(r.body[i]))

                # exclude the cases that would be new nullary rules!
                if len(new_body) > 0:
                    rcfg.add(v, f(r.head), *new_body)

        return rcfg

    def separate_start(self):
        """
        Ensure that the start symbol does not appear on the RHS of any rule.

        Returns:
            A new CFG with start symbol only on LHS
        """
        # create a new start symbol if the current one appears on the rhs of any existing rule
        if self.S in {y for r in self for y in r.body}:
            S = _gen_nt(self.S)
            new = self.spawn(S=S)
            # preterminal rules
            new.add(self.R.one, S, self.S)
            for r in self:
                new.add(r.w, r.head, *r.body)
            return new
        else:
            return self

    def separate_terminals(self):
        """
        Ensure that each terminal is produced by a preterminal rule.

        Returns:
            A new CFG with terminals only in preterminal rules
        """
        one = self.R.one
        new = self.spawn()

        _preterminal = {}

        def preterminal(x):
            y = _preterminal.get(x)
            if y is None:
                y = new.add(one, _gen_nt(), x)
                _preterminal[x] = y
            return y

        for r in self:
            if len(r.body) == 1 and self.is_terminal(r.body[0]):
                new.add(r.w, r.head, *r.body)
            else:
                new.add(
                    r.w,
                    r.head,
                    *(
                        (preterminal(y).head if self.is_terminal(y) else y)
                        for y in r.body
                    ),
                )

        return new

    def binarize(self):
        """
        Return an equivalent grammar with arity ≤ 2.

        Returns:
            A new CFG with binary rules
        """
        new = self.spawn()

        stack = list(self)
        while stack:
            p = stack.pop()
            if len(p.body) <= 2:
                new.add(p.w, p.head, *p.body)
            else:
                stack.extend(self._fold(p, [(0, 1)]))

        return new

    def _fold(self, p, I):
        """
        Helper method for binarization - folds a rule into binary rules.

        Args:
            p: The rule to fold
            I: List of (start,end) indices for folding

        Returns:
            List of new binary rules
        """
        # new productions
        P, heads = [], []
        for i, j in I:
            head = _gen_nt()
            heads.append(head)
            body = p.body[i : j + 1]
            P.append(Rule(self.R.one, head, body))

        # new "head" production
        body = tuple()
        start = 0
        for (end, n), head in zip(I, heads):
            body += p.body[start:end] + (head,)
            start = n + 1
        body += p.body[start:]
        P.append(Rule(p.w, p.head, body))

        return P

    @cached_property
    def cnf(self):
        """
        Transform this grammar into Chomsky Normal Form (CNF).

        Returns:
            A new CFG in CNF
        """
        new = (
            self.separate_terminals()
            .nullaryremove(binarize=True)
            .trim()
            .unaryremove()
            .trim()
        )
        assert new.in_cnf(), "\n".join(
            str(r) for r in new._find_invalid_cnf_rule()
        )  # pragma: no cover
        return new

    # TODO: make CNF grammars a specialized subclass of CFG.
    @cached_property
    def _cnf(self):
        """
        Note: Throws an exception if the grammar is not in CNF.

        Returns:
            Tuple of (nullary weight, terminal rules dict, binary rules list)
        """
        nullary = self.R.zero
        terminal = defaultdict(list)
        binary = []
        for r in self:
            if len(r.body) == 0:
                nullary += r.w
                assert r.head == self.S, [self.S, r]
            elif len(r.body) == 1:
                terminal[r.body[0]].append(r)
                assert self.is_terminal(r.body[0])
            else:
                assert len(r.body) == 2
                binary.append(r)
                assert self.is_nonterminal(r.body[0])
                assert self.is_nonterminal(r.body[1])
        return (nullary, terminal, binary)

    def in_cnf(self):
        """
        Return true if the grammar is in CNF.

        Returns:
            True if grammar is in Chomsky Normal Form
        """
        return len(list(self._find_invalid_cnf_rule())) == 0

    def _find_invalid_cnf_rule(self):
        """
        Find the rules (if any) that violate Chomsky Normal Form.

        Yields:
            Rules that violate CNF
        """
        for r in self:
            assert r.head in self.N
            if len(r.body) == 0 and r.head == self.S:
                continue
            elif len(r.body) == 1 and self.is_terminal(r.body[0]):
                continue
            elif len(r.body) == 2 and all(
                self.is_nonterminal(y) and y != self.S for y in r.body
            ):
                continue
            else:
                yield r

    #    def has_nullary(self):
    #        return any((len(p.body) == 0) for p in self if p.head != self.S)

    def unfold(self, i, k):
        """
        Apply the unfolding transformation to rule i and subgoal k.

        Args:
            i: Index of rule to unfold
            k: Index of subgoal in rule body

        Returns:
            A new CFG with the rule unfolded
        """
        assert isinstance(i, int) and isinstance(k, int)
        s = self.rules[i]
        assert self.is_nonterminal(s.body[k])

        new = self.spawn()
        for j, r in enumerate(self):
            if j != i:
                new.add(r.w, r.head, *r.body)

        for r in self.rhs[s.body[k]]:
            new.add(s.w * r.w, s.head, *s.body[:k], *r.body, *s.body[k + 1 :])

        return new
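The unfolding transformation substitutes every expansion of one subgoal into the chosen rule, multiplying weights. A minimal sketch over plain `(w, head, body)` triples with float weights (illustrative, not the library's API):

```python
def unfold(rules, i, k):
    """Replace rule i by one copy per expansion of its k-th subgoal,
    multiplying the two rule weights."""
    w, head, body = rules[i]
    out = [r for j, r in enumerate(rules) if j != i]
    for w2, head2, body2 in rules:
        if head2 == body[k]:
            out.append((w * w2, head, body[:k] + body2 + body[k + 1:]))
    return out

rules = [(1.0, "S", ("A", "b")), (0.5, "A", ("a",)), (0.5, "A", ("a", "A"))]
new = unfold(rules, 0, 0)  # unfold subgoal A in rule 0
```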

    def dependency_graph(self):
        """
        Head-to-body dependency graph of the rules of the grammar.

        Returns:
            A WeightedGraph representing dependencies between symbols
        """
        deps = WeightedGraph(Boolean)
        for r in self:
            for y in r.body:
                deps[r.head, y] += Boolean.one
        deps.N |= self.N
        deps.N |= self.V
        return deps

    # TODO: the default treesum algorithm should probably be SCC-decomposed newton's method
    # def agenda(self, tol=1e-12, maxiter=float('inf')):
    def agenda(self, tol=1e-12, maxiter=100_000):
        """
        Agenda-based semi-naive evaluation for treesums.

        Args:
            tol: Convergence tolerance
            maxiter: Maximum iterations

        Returns:
            A chart containing the treesum weights
        """
        old = self.R.chart()

        # precompute the mapping from updates to where they need to go
        routing = defaultdict(list)
        for r in self:
            for k in range(len(r.body)):
                routing[r.body[k]].append((r, k))

        deps = self.dependency_graph()
        blocks = deps.blocks
        bucket = deps.buckets

        # helper function
        def update(x, W):
            change[bucket[x]][x] += W

        change = defaultdict(self.R.chart)
        for a in self.V:
            update(a, self.R.one)

        for r in self:
            if len(r.body) == 0:
                update(r.head, r.w)

        b = len(blocks)
        iteration = 0
        while b >= 0:
            iteration += 1

            # Move on to the next block
            if len(change[b]) == 0 or iteration > maxiter:
                b -= 1
                iteration = 0  # reset iteration number for the next bucket
                continue

            u, v = change[b].popitem()

            new = old[u] + v

            if self.R.metric(old[u], new) <= tol:
                continue

            for r, k in routing[u]:
                W = r.w
                for j in range(len(r.body)):
                    if u == r.body[j]:
                        if j < k:
                            W *= new
                        elif j == k:
                            W *= v
                        else:
                            W *= old[u]
                    else:
                        W *= old[r.body[j]]

                update(r.head, W)

            old[u] = new

        return old

    def naive_bottom_up(self, *, tol=1e-12, timeout=100_000):
        "Naive bottom-up evaluation for treesums; better to use `agenda`."

        def _approx_equal(U, V):
            return all((self.R.metric(U[X], V[X]) <= tol) for X in self.N)

        R = self.R
        V = R.chart()
        counter = 0
        while counter < timeout:
            U = self._bottom_up_step(V)
            if _approx_equal(U, V):
                break
            V = U
            counter += 1
        return V

    def _bottom_up_step(self, V):
        R = self.R
        one = R.one
        U = R.chart()
        for a in self.V:
            U[a] = one
        for p in self:
            update = p.w
            for X in p.body:
                if self.is_nonterminal(X):
                    update *= V[X]
            U[p.head] += update
        return U
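The same fixed-point iteration can be sketched self-contained over the real semiring, where each step applies `V[head] += w * prod(V[y])` with terminals contributing weight one (illustrative, not the library's chart types):

```python
def treesum(rules, terminals, tol=1e-12, max_iter=10_000):
    """Least fixed point of V[head] = sum over rules of w * prod of V[y]
    over nonterminal body symbols y (terminals count as 1.0)."""
    V = {}
    for _ in range(max_iter):
        U = {}
        for w, head, body in rules:
            p = w
            for y in body:
                if y not in terminals:
                    p *= V.get(y, 0.0)
            U[head] = U.get(head, 0.0) + p
        if all(abs(U.get(X, 0.0) - V.get(X, 0.0)) <= tol for X in U):
            return U
        V = U
    return V

# S -> S S (0.25) | a (0.5): the treesum solves Z = 0.25 Z^2 + 0.5,
# whose least root is 2 - sqrt(2)
rules = [(0.25, "S", ("S", "S")), (0.5, "S", ("a",))]
Z = treesum(rules, {"a"})["S"]
```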

    def prefix_weight(self, xs):
        "Total weight of all derivations that have `xs` as a prefix."
        return self.prefix_grammar(xs)

    @cached_property
    def prefix_grammar(self):
        """
        The prefix grammar generates the prefix language of the parent grammar.
        PG[x] = sum_[s in Σ^*] G[xs].
        """
        pg = self.spawn()
        W = self.agenda()

        # Generate a unique ID for this prefix transformation to avoid collisions
        arrow_id = _arrow_nt.next_id()

        pg.S = _gen_nt(self.S)
        pg.add(
            self.R.one, pg.S, _arrow_nt(self.S, arrow_id)
        )  # Attach start symbol to arrow nonterminal
        pg.add(
            W[self.S],
            pg.S,
        )  # The prefix weight of the empty string

        for r in self:
            pg.add(r.w, r.head, *r.body)
            w = self.R.one
            for i in range(
                len(r.body) - 1, -1, -1
            ):  # Add the prefixed rules, with the "arrow" non-terminals on the right.
                if self.is_terminal(r.body[i]):  # For a terminal, a^ := a
                    pg.add(r.w * w, _arrow_nt(r.head, arrow_id), *r.body[:i], r.body[i])
                else:  # Otherwise, the arrow is transmitted downwards on the right spine of the derivation.
                    pg.add(
                        r.w * w,
                        _arrow_nt(r.head, arrow_id),
                        *r.body[:i],
                        _arrow_nt(r.body[i], arrow_id),
                    )
                w = w * W[r.body[i]]
        return pg
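For a finite language, the defining identity `PG[x] = sum_{s in Σ^*} G[xs]` can be checked by brute force over an explicit weighted language (a sketch for cross-checking, not the library's API):

```python
def prefix_weight(language, xs):
    """Sum the weights of all strings in a finite weighted language
    that have `xs` as a prefix: PG[xs] = sum over s of G[xs + s]."""
    return sum(w for string, w in language.items() if string[:len(xs)] == xs)

G = {"ab": 0.5, "ac": 0.3, "bc": 0.2}   # toy weighted language
```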

    def derivatives(self, s):
        "Return the sequence of derivatives for each prefix of `s`."
        M = len(s)
        D = [self]
        for m in range(M):
            D.append(D[m].derivative(s[m]))
        return D

    # Implementation note: This implementation of the derivative grammar
    # performs nullary elimination at the same time.
    def derivative(self, a, i=0):
        "Return a grammar that generates the derivative with respect to `a`."

        def slash(x, y):
            return Slash(x, y, i=i)

        D = self.spawn(S=slash(self.S, a))
        U = self.null_weight()
        for r in self:
            D.add(r.w, r.head, *r.body)
            delta = self.R.one
            for k, y in enumerate(r.body):
                if slash(r.head, a) in self.N:
                    continue  # SKIP!
                if self.is_terminal(y):
                    if y == a:
                        D.add(delta * r.w, slash(r.head, a), *r.body[k + 1 :])
                else:
                    D.add(
                        delta * r.w,
                        slash(r.head, a),
                        slash(r.body[k], a),
                        *r.body[k + 1 :],
                    )
                delta *= U[y]
        return D
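On an explicit finite weighted language, the derivative is simply the residual language after consuming one symbol; a brute-force version useful for cross-checking the grammar transformation (illustrative only):

```python
def derivative(language, a):
    """D_a L maps s -> L[a + s]: keep strings starting with `a`, drop that symbol."""
    return {s[1:]: w for s, w in language.items() if s[:1] == a}

L = {"ab": 0.5, "ac": 0.3, "bb": 0.2}
Da = derivative(L, "a")   # the residual language after reading "a"
```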

    def _compose_bottom_up_epsilon(self, fst):
        "Determine which items of the composition grammar are supported"

        A = set()

        I = defaultdict(set)  # incomplete items
        C = defaultdict(set)  # complete items
        R = defaultdict(set)  # rules indexed by first subgoal; non-nullary

        special_rules = [Rule(self.R.one, a, (EPSILON, a)) for a in self.V] + [
            Rule(self.R.one, Other(self.S), (self.S,)),
            Rule(self.R.one, Other(self.S), (Other(self.S), EPSILON)),
        ]

        for r in itertools.chain(self, special_rules):
            if len(r.body) > 0:
                R[r.body[0]].add(r)

        # we have two base cases:
        #
        # base case 1: arcs
        for i, (a, _), j, _ in fst.arcs():
            A.add((i, a, (), j))  # empty tuple -> the rule 'complete'

        # base case 2: nullary rules
        for r in self:
            if len(r.body) == 0:
                for i in fst.states:
                    A.add((i, r.head, (), i))

        # drain the agenda
        while A:
            (i, X, Ys, j) = A.pop()

            # No pending items ==> the item is complete
            if not Ys:
                if j in C[i, X]:
                    continue
                C[i, X].add(j)

                # combine the newly completed item with incomplete rules that are
                # looking for an item like this one
                for h, X1, Zs in I[i, X]:
                    A.add((h, X1, Zs[1:], j))

                # initialize rules that can start with an item like this one
                for r in R[X]:
                    A.add((i, r.head, r.body[1:], j))

            # Still have pending items ==> advance the pending items
            else:
                if (i, X, Ys) in I[j, Ys[0]]:
                    continue
                I[j, Ys[0]].add((i, X, Ys))

                for k in C[j, Ys[0]]:
                    A.add((i, X, Ys[1:], k))

        return C

    def __matmul__(self, fst):
        "Return a CFG denoting the pointwise product or composition of `self` and `fst`."

        # coerce something sequence like into a diagonal FST
        if isinstance(fst, (str, tuple)):
            fst = FST.from_string(fst, self.R)
        # coerce something FSA-like into an FST, might throw an error
        if not isinstance(fst, FST):
            fst = fst.to_fst()

        # Initialize the new CFG:
        # - its start symbol is chosen arbitrarily to be `self.S`
        # - its alphabet changes: it is now the 'output' alphabet of the transducer
        new_start = self.S
        new = self.spawn(S=new_start, V=fst.B - {EPSILON})

        # The bottom-up intersection algorithm is a two-pass algorithm
        #
        # Pass 1: Determine the set of items that are possibly nonzero-valued
        C = self._compose_bottom_up_epsilon(fst)

        special_rules = [Rule(self.R.one, a, (EPSILON, a)) for a in self.V] + [
            Rule(self.R.one, Other(self.S), (self.S,)),
            Rule(self.R.one, Other(self.S), (Other(self.S), EPSILON)),
        ]

        def join(start, Ys):
            """
            Helper method; expands the rule body

            Given Ys = [Y_1, ... Y_K], we will enumerate expansions of the form

            (s_0, Y_1, s_1), (s_1, Y_2, s_2), ..., (s_{K-1}, Y_K, s_K)

            where each (s_{k-1}, Y_k, s_k) in the expansion is a completed item
            (i.e., \forall k: (s_{k-1}, Y_k, s_k) in C).
            """
            if not Ys:
                yield []
            else:
                for K in C[start, Ys[0]]:
                    for rest in join(K, Ys[1:]):
                        yield [(start, Ys[0], K)] + rest

        start = {I for (I, _) in C}

        for r in itertools.chain(self, special_rules):
            if len(r.body) == 0:
                for s in fst.states:
                    new.add(r.w, (s, r.head, s))
            else:
                for I in start:
                    for rhs in join(I, r.body):
                        K = rhs[-1][-1]
                        new.add(r.w, (I, r.head, K), *rhs)

        for i, wi in fst.start.items():
            for k, wf in fst.stop.items():
                new.add(wi * wf, new_start, (i, Other(self.S), k))

        for i, (a, b), j, w in fst.arcs():
            if b == EPSILON:
                new.add(w, (i, a, j))
            else:
                new.add(w, (i, a, j), b)
        return new

    def truncate_length(self, max_length):
        "Transform this grammar so that it only generates strings with length ≤ `max_length`."
        from genlm.grammar import WFSA

        m = WFSA(self.R)
        m.add_I(0, self.R.one)
        m.add_F(0, self.R.one)
        for t in range(max_length):
            for x in self.V:
                m.add_arc(t, x, t + 1, self.R.one)
            m.add_F(t + 1, self.R.one)
        return self @ m

    def materialize(self, max_length):
        "Return a `Chart` with this grammar's weighted language for strings ≤ `max_length`."
        return self.cnf.language(max_length).filter(lambda x: len(x) <= max_length)

    def to_bytes(self):
        """Convert terminal symbols from strings to bytes representation.

        This method creates a new grammar where all terminal string symbols are
        converted to their UTF-8 byte representation. Non-terminal symbols are
        preserved as-is.

        Returns:
            CFG: A new grammar with byte terminal symbols

        Raises:
            ValueError: If a terminal symbol is not a string
        """
        new = self.spawn(S=self.S, R=self.R, V=set())

        for r in self:
            new_body = []
            for x in r.body:
                if self.is_terminal(x):
                    if not isinstance(x, str):
                        raise ValueError(f"unsupported terminal type: {type(x)}")
                    bs = list(x.encode("utf-8"))
                    for b in bs:
                        new.V.add(b)
                    new_body.extend(bs)
                else:
                    new_body.append(x)
            new.add(r.w, r.head, *new_body)

        return new
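The core of the conversion is plain UTF-8 encoding of each string terminal into its byte values; a minimal sketch of the body rewrite (illustrative names, not the library's API):

```python
def body_to_bytes(body, terminals):
    """Expand each string terminal into its UTF-8 byte values (ints 0..255),
    leaving nonterminal symbols untouched."""
    new_body = []
    for x in body:
        if x in terminals:
            new_body.extend(x.encode("utf-8"))  # iterating bytes yields ints
        else:
            new_body.append(x)
    return new_body

out = body_to_bytes(["X", "é"], {"é"})   # "é" encodes to two bytes
```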

cnf cached property

Transform this grammar into Chomsky Normal Form (CNF).

Returns:

Type Description

A new CFG in CNF

expected_length property

Compute the expected length of a string using the Expectation semiring.

Returns:

Type Description

The expected length of strings generated by this grammar

Raises:

Type Description
AssertionError

If grammar is not over the Float semiring

num_rules property

Return number of rules in the grammar.

prefix_grammar cached property

The prefix grammar generates the prefix language of the parent grammar. PG[x] = sum_[s in Σ^*] G[xs].

rhs cached property

Map from each nonterminal to the list of rules with it as their left-hand side.

Returns:

Type Description

A dict mapping nonterminals to lists of rules

size property

Return total size of the grammar (sum of rule lengths).

__call__(xs)

Compute the total weight of the sequence xs.

Parameters:

Name Type Description Default
xs

A sequence of terminal symbols

required

Returns:

Type Description

The total weight of all derivations of xs

Source code in genlm/grammar/cfg.py
def __call__(self, xs):
    """
    Compute the total weight of the sequence xs.

    Args:
        xs: A sequence of terminal symbols

    Returns:
        The total weight of all derivations of xs
    """
    self = self.cnf  # need to do this here because the start symbol might change
    return self._parse_chart(xs)[0, self.S, len(xs)]
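Under the hood this is a CKY-style chart computation over the CNF grammar. A self-contained sketch for float weights (illustrative, not the internal `_parse_chart`):

```python
def total_weight(xs, start, unary, binary):
    """CKY: chart[i, X, k] = total weight of X deriving xs[i:k].
    `unary` maps a terminal to [(w, X)]; `binary` is a list of (w, X, Y, Z)."""
    n = len(xs)
    chart = {}
    for i, a in enumerate(xs):
        for w, X in unary.get(a, []):
            chart[i, X, i + 1] = chart.get((i, X, i + 1), 0.0) + w
    for span in range(2, n + 1):
        for i in range(n - span + 1):
            k = i + span
            for j in range(i + 1, k):
                for w, X, Y, Z in binary:
                    left = chart.get((i, Y, j), 0.0)
                    right = chart.get((j, Z, k), 0.0)
                    if left and right:
                        chart[i, X, k] = chart.get((i, X, k), 0.0) + w * left * right
    return chart.get((0, start, n), 0.0)

# S -> A B (1.0), A -> a (0.5), B -> b (0.4); weight("ab") = 0.2
w = total_weight("ab", "S",
                 {"a": [(0.5, "A")], "b": [(0.4, "B")]},
                 [(1.0, "S", "A", "B")])
```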

__getitem__(root)

Return a grammar that denotes the sublanguage of the nonterminal root.

Parameters:

Name Type Description Default
root

The nonterminal to use as the new start symbol

required

Returns:

Type Description

A new CFG with root as the start symbol

Source code in genlm/grammar/cfg.py
def __getitem__(self, root):
    """
    Return a grammar that denotes the sublanguage of the nonterminal `root`.

    Args:
        root: The nonterminal to use as the new start symbol

    Returns:
        A new CFG with root as the start symbol
    """
    new = self.spawn(S=root)
    for r in self:
        new.add(r.w, r.head, *r.body)
    return new

__init__(R, S, V)

Initialize a weighted CFG.

Parameters:

Name Type Description Default
R

The semiring for rule weights

required
S

The start symbol (nonterminal)

required
V

The set of terminal symbols (vocabulary)

required
Source code in genlm/grammar/cfg.py
def __init__(self, R, S, V):
    """
    Initialize a weighted CFG.

    Args:
        R: The semiring for rule weights
        S: The start symbol (nonterminal)
        V: The set of terminal symbols (vocabulary)
    """
    self.R = R  # semiring
    self.V = V  # alphabet
    self.N = {S}  # nonterminals
    self.S = S  # unique start symbol
    self.rules = []  # rules
    self._trim_cache = [None, None]

__iter__()

Iterate over the rules in the grammar.

Source code in genlm/grammar/cfg.py
def __iter__(self):
    """Iterate over the rules in the grammar."""
    return iter(self.rules)

__len__()

Return number of rules in the grammar.

Source code in genlm/grammar/cfg.py
def __len__(self):
    """Return number of rules in the grammar."""
    return len(self.rules)

__matmul__(fst)

Return a CFG denoting the pointwise product or composition of self and fst.

Source code in genlm/grammar/cfg.py
def __matmul__(self, fst):
    "Return a CFG denoting the pointwise product or composition of `self` and `fst`."

    # coerce something sequence like into a diagonal FST
    if isinstance(fst, (str, tuple)):
        fst = FST.from_string(fst, self.R)
    # coerce something FSA-like into an FST, might throw an error
    if not isinstance(fst, FST):
        fst = fst.to_fst()

    # Initialize the new CFG:
    # - its start symbol is chosen arbitrarily to be `self.S`
    # - its alphabet changes: it is now the 'output' alphabet of the transducer
    new_start = self.S
    new = self.spawn(S=new_start, V=fst.B - {EPSILON})

    # The bottom-up intersection algorithm is a two-pass algorithm
    #
    # Pass 1: Determine the set of items that are possibly nonzero-valued
    C = self._compose_bottom_up_epsilon(fst)

    special_rules = [Rule(self.R.one, a, (EPSILON, a)) for a in self.V] + [
        Rule(self.R.one, Other(self.S), (self.S,)),
        Rule(self.R.one, Other(self.S), (Other(self.S), EPSILON)),
    ]

    def join(start, Ys):
        """
        Helper method; expands the rule body

        Given Ys = [Y_1, ... Y_K], we will enumerate expansions of the form

        (s_0, Y_1, s_1), (s_1, Y_2, s_2), ..., (s_{K-1}, Y_K, s_K)

        where each (s_{k-1}, Y_k, s_k) in the expansion is a completed item
        (i.e., \forall k: (s_{k-1}, Y_k, s_k) in C).
        """
        if not Ys:
            yield []
        else:
            for K in C[start, Ys[0]]:
                for rest in join(K, Ys[1:]):
                    yield [(start, Ys[0], K)] + rest

    start = {I for (I, _) in C}

    for r in itertools.chain(self, special_rules):
        if len(r.body) == 0:
            for s in fst.states:
                new.add(r.w, (s, r.head, s))
        else:
            for I in start:
                for rhs in join(I, r.body):
                    K = rhs[-1][-1]
                    new.add(r.w, (I, r.head, K), *rhs)

    for i, wi in fst.start.items():
        for k, wf in fst.stop.items():
            new.add(wi * wf, new_start, (i, Other(self.S), k))

    for i, (a, b), j, w in fst.arcs():
        if b == EPSILON:
            new.add(w, (i, a, j))
        else:
            new.add(w, (i, a, j), b)
    return new

__repr__()

Return string representation of the grammar.

Source code in genlm/grammar/cfg.py
def __repr__(self):
    """Return string representation of the grammar."""
    return "Grammar {\n%s\n}" % "\n".join(f"  {r}" for r in self)

add(w, head, *body)

Add a rule of the form w: head -> body1, body2, ... body_k.

Parameters:

Name Type Description Default
w

The rule weight

required
head

The left-hand side nonterminal

required
*body

The right-hand side symbols

()

Returns:

Type Description

The added rule, or None if weight is zero

Source code in genlm/grammar/cfg.py
def add(self, w, head, *body):
    """
    Add a rule of the form w: head -> body1, body2, ... body_k.

    Args:
        w: The rule weight
        head: The left-hand side nonterminal
        *body: The right-hand side symbols

    Returns:
        The added rule, or None if weight is zero
    """
    if w == self.R.zero:
        return  # skip rules with weight zero
    self.N.add(head)
    r = Rule(w, head, body)
    self.rules.append(r)
    return r

agenda(tol=1e-12, maxiter=100000)

Agenda-based semi-naive evaluation for treesums.

Parameters:

Name Type Description Default
tol

Convergence tolerance

1e-12
maxiter

Maximum iterations

100000

Returns:

Type Description

A chart containing the treesum weights

Source code in genlm/grammar/cfg.py
def agenda(self, tol=1e-12, maxiter=100_000):
    """
    Agenda-based semi-naive evaluation for treesums.

    Args:
        tol: Convergence tolerance
        maxiter: Maximum iterations

    Returns:
        A chart containing the treesum weights
    """
    old = self.R.chart()

    # precompute the mapping from updates to where they need to go
    routing = defaultdict(list)
    for r in self:
        for k in range(len(r.body)):
            routing[r.body[k]].append((r, k))

    deps = self.dependency_graph()
    blocks = deps.blocks
    bucket = deps.buckets

    # helper function
    def update(x, W):
        change[bucket[x]][x] += W

    change = defaultdict(self.R.chart)
    for a in self.V:
        update(a, self.R.one)

    for r in self:
        if len(r.body) == 0:
            update(r.head, r.w)

    b = len(blocks)
    iteration = 0
    while b >= 0:
        iteration += 1

        # Move on to the next block
        if len(change[b]) == 0 or iteration > maxiter:
            b -= 1
            iteration = 0  # reset iteration number for the next bucket
            continue

        u, v = change[b].popitem()

        new = old[u] + v

        if self.R.metric(old[u], new) <= tol:
            continue

        for r, k in routing[u]:
            W = r.w
            for j in range(len(r.body)):
                if u == r.body[j]:
                    if j < k:
                        W *= new
                    elif j == k:
                        W *= v
                    else:
                        W *= old[u]
                else:
                    W *= old[r.body[j]]

            update(r.head, W)

        old[u] = new

    return old

assert_equal(other, verbose=False, throw=True)

Assertion for the equality of self and other modulo rule reordering.

Parameters:

Name Type Description Default
other

The grammar to compare against

required
verbose

If True, print differences

False
throw

If True, raise AssertionError on inequality

True

Raises:

Type Description
AssertionError

If grammars are not equal and throw=True

Source code in genlm/grammar/cfg.py
def assert_equal(self, other, verbose=False, throw=True):
    """
    Assertion for the equality of self and other modulo rule reordering.

    Args:
        other: The grammar to compare against
        verbose: If True, print differences
        throw: If True, raise AssertionError on inequality

    Raises:
        AssertionError: If grammars are not equal and throw=True
    """
    assert verbose or throw
    if isinstance(other, str):
        other = self.__class__.from_string(other, self.R)
    if verbose:
        # TODO: need to check the weights in the print out; we do it in the assertion
        S = set(self.rules)
        G = set(other.rules)
        for r in sorted(S | G, key=str):
            if r in S and r in G:
                continue
            # if r in S and r not in G: continue
            # if r not in S and r in G: continue
            print(
                colors.mark(r in S),
                # colors.mark(r in S and r in G),
                colors.mark(r in G),
                r,
            )
    assert not throw or Counter(self.rules) == Counter(other.rules), (
        f"\n\nhave=\n{str(self)}\nwant=\n{str(other)}"
    )

binarize()

Return an equivalent grammar with arity ≤ 2.

Returns:

Type Description

A new CFG with binary rules

Source code in genlm/grammar/cfg.py
def binarize(self):
    """
    Return an equivalent grammar with arity ≤ 2.

    Returns:
        A new CFG with binary rules
    """
    new = self.spawn()

    stack = list(self)
    while stack:
        p = stack.pop()
        if len(p.body) <= 2:
            new.add(p.w, p.head, *p.body)
        else:
            stack.extend(self._fold(p, [(0, 1)]))

    return new

cotrim()

Trim the grammar so that all nonterminals are generating.

Returns:

Type Description

A new CFG with only generating nonterminals

Source code in genlm/grammar/cfg.py
def cotrim(self):
    """
    Trim the grammar so that all nonterminals are generating.

    Returns:
        A new CFG with only generating nonterminals
    """
    return self.trim(bottomup_only=True)

dependency_graph()

Head-to-body dependency graph of the rules of the grammar.

Returns:

Type Description

A WeightedGraph representing dependencies between symbols

Source code in genlm/grammar/cfg.py
def dependency_graph(self):
    """
    Head-to-body dependency graph of the rules of the grammar.

    Returns:
        A WeightedGraph representing dependencies between symbols
    """
    deps = WeightedGraph(Boolean)
    for r in self:
        for y in r.body:
            deps[r.head, y] += Boolean.one
    deps.N |= self.N
    deps.N |= self.V
    return deps

derivations(X, H)

Enumerate derivations of symbol X with height <= H.

Parameters:

Name Type Description Default
X

The symbol to derive from (default: start symbol)

required
H

Maximum derivation height

required

Yields:

Type Description

Derivation objects representing derivation trees

Source code in genlm/grammar/cfg.py
def derivations(self, X, H):
    """
    Enumerate derivations of symbol X with height <= H.

    Args:
        X: The symbol to derive from (default: start symbol)
        H: Maximum derivation height

    Yields:
        Derivation objects representing derivation trees
    """
    if X is None:
        X = self.S
    if self.is_terminal(X):
        yield X
    elif H <= 0:
        return
    else:
        for r in self.rhs[X]:
            for ys in self._derivations_list(r.body, H - 1):
                yield Derivation(r, X, *ys)

derivative(a, i=0)

Return a grammar that generates the derivative with respect to a.

Source code in genlm/grammar/cfg.py
def derivative(self, a, i=0):
    "Return a grammar that generates the derivative with respect to `a`."

    def slash(x, y):
        return Slash(x, y, i=i)

    D = self.spawn(S=slash(self.S, a))
    U = self.null_weight()
    for r in self:
        D.add(r.w, r.head, *r.body)
        delta = self.R.one
        for k, y in enumerate(r.body):
            if slash(r.head, a) in self.N:
                continue  # SKIP!
            if self.is_terminal(y):
                if y == a:
                    D.add(delta * r.w, slash(r.head, a), *r.body[k + 1 :])
            else:
                D.add(
                    delta * r.w,
                    slash(r.head, a),
                    slash(r.body[k], a),
                    *r.body[k + 1 :],
                )
            delta *= U[y]
    return D

derivatives(s)

Return the sequence of derivatives for each prefix of s.

Source code in genlm/grammar/cfg.py
def derivatives(self, s):
    "Return the sequence of derivatives for each prefix of `s`."
    M = len(s)
    D = [self]
    for m in range(M):
        D.append(D[m].derivative(s[m]))
    return D

from_string(string, semiring, comment='#', start='S', is_terminal=lambda x: not x[0].isupper()) classmethod

Create a CFG from a string representation.

Parameters:

Name Type Description Default
string

The grammar rules as a string

required
semiring

The semiring for rule weights

required
comment

Comment character to ignore lines (default: '#')

'#'
start

Start symbol (default: 'S')

'S'
is_terminal

Function to identify terminal symbols (default: lowercase first letter)

lambda x: not x[0].isupper()

Returns:

Type Description

A new CFG instance

Source code in genlm/grammar/cfg.py
@classmethod
def from_string(
    cls,
    string,
    semiring,
    comment="#",
    start="S",
    is_terminal=lambda x: not x[0].isupper(),
):
    """
    Create a CFG from a string representation.

    Args:
        string: The grammar rules as a string
        semiring: The semiring for rule weights
        comment: Comment character to ignore lines (default: '#')
        start: Start symbol (default: 'S')
        is_terminal: Function to identify terminal symbols (default: lowercase first letter)

    Returns:
        A new CFG instance
    """
    V = set()
    cfg = cls(R=semiring, S=start, V=V)
    string = string.replace("->", "→")  # synonym for the arrow
    for line in string.split("\n"):
        line = line.strip()
        if not line or line.startswith(comment):
            continue
        try:
            [(w, lhs, rhs)] = re.findall(r"(.*):\s*(\S+)\s*→\s*(.*)$", line)
            lhs = lhs.strip()
            rhs = rhs.strip().split()
            for x in rhs:
                if is_terminal(x):
                    V.add(x)
            cfg.add(semiring.from_string(w), lhs, *rhs)
        except ValueError:
            raise ValueError(f"bad input line:\n{line}")  # pylint: disable=W0707
    return cfg
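The line format accepted by `from_string` is `weight: LHS → RHS…`, with `->` accepted as a synonym for the arrow. Parsing one line can be reproduced standalone with the same regular expression:

```python
import re

line = "0.5: S -> a S".replace("->", "→")  # normalize the arrow, as from_string does
[(w, lhs, rhs)] = re.findall(r"(.*):\s*(\S+)\s*→\s*(.*)$", line)
body = rhs.strip().split()
# w = "0.5", lhs = "S", body = ["a", "S"]
```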

has_unary_cycle()

Check if the grammar has unary cycles.

Returns:

Type Description

True if the grammar contains unary cycles

Source code in genlm/grammar/cfg.py
def has_unary_cycle(self):
    """
    Check if the grammar has unary cycles.

    Returns:
        True if the grammar contains unary cycles
    """
    f = self._unary_graph().buckets
    return any(
        True for r in self if len(r.body) == 1 and f.get(r.head) == f.get(r.body[0])
    )

in_cnf()

Return true if the grammar is in CNF.

Returns:

Type Description

True if grammar is in Chomsky Normal Form

Source code in genlm/grammar/cfg.py
def in_cnf(self):
    """
    Return true if the grammar is in CNF.

    Returns:
        True if grammar is in Chomsky Normal Form
    """
    return len(list(self._find_invalid_cnf_rule())) == 0

is_nonterminal(X)

Return True if X is a nonterminal symbol.

Source code in genlm/grammar/cfg.py
def is_nonterminal(self, X):
    """Return True if X is a nonterminal symbol."""
    return not self.is_terminal(X)

is_terminal(x)

Return True if x is a terminal symbol.

Source code in genlm/grammar/cfg.py
def is_terminal(self, x):
    """Return True if x is a terminal symbol."""
    return x in self.V

language(depth)

Enumerate strings generated by this cfg by derivations up to the given depth.

Parameters:

Name Type Description Default
depth

Maximum derivation depth to consider

required

Returns:

Type Description

A chart containing the weighted language up to the given depth

Source code in genlm/grammar/cfg.py
def language(self, depth):
    """
    Enumerate strings generated by this cfg by derivations up to the given depth.

    Args:
        depth: Maximum derivation depth to consider

    Returns:
        A chart containing the weighted language up to the given depth
    """
    lang = self.R.chart()
    for d in self.derivations(self.S, depth):
        lang[d.Yield()] += d.weight()
    return lang

map_values(f, R)

Return a new grammar that is the result of applying f: self.R -> R to each rule's weight.

Parameters:

Name Type Description Default
f

Function to map weights

required
R

New semiring for weights

required

Returns:

Type Description

A new CFG with mapped weights

Source code in genlm/grammar/cfg.py
def map_values(self, f, R):
    """
    Return a new grammar that is the result of applying f: self.R -> R to each rule's weight.

    Args:
        f: Function to map weights
        R: New semiring for weights

    Returns:
        A new CFG with mapped weights
    """
    new = self.spawn(R=R)
    for r in self:
        new.add(f(r.w), r.head, *r.body)
    return new

materialize(max_length)

Return a Chart with this grammar's weighted language for strings ≤ max_length.

Source code in genlm/grammar/cfg.py
def materialize(self, max_length):
    "Return a `Chart` with this grammar's weighted language for strings ≤ `max_length`."
    return self.cnf.language(max_length).filter(lambda x: len(x) <= max_length)

naive_bottom_up(*, tol=1e-12, timeout=100000)

Naive bottom-up evaluation for treesums; better to use agenda.

Source code in genlm/grammar/cfg.py
def naive_bottom_up(self, *, tol=1e-12, timeout=100_000):
    "Naive bottom-up evaluation for treesums; better to use `agenda`."

    def _approx_equal(U, V):
        return all((self.R.metric(U[X], V[X]) <= tol) for X in self.N)

    R = self.R
    V = R.chart()
    counter = 0
    while counter < timeout:
        U = self._bottom_up_step(V)
        if _approx_equal(U, V):
            break
        V = U
        counter += 1
    return V
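
The fixpoint this method iterates toward can be seen in a standalone sketch over the real semiring (plain dicts, hypothetical names; the library's semiring-generic version uses `R.chart()` and `R.metric` instead of floats and `abs`):

```python
# Toy grammar: S -> S S (0.25) | a (0.75).
# The treesum chart V maps each nonterminal to the total weight of all
# derivations from it; it is the least fixpoint of one bottom-up step.
rules = [(0.25, "S", ("S", "S")), (0.75, "S", ("a",))]
nonterminals = {"S"}

def bottom_up_step(V):
    U = dict.fromkeys(nonterminals, 0.0)
    for w, head, body in rules:
        prod = w
        for y in body:
            prod *= V[y] if y in nonterminals else 1.0  # terminals weigh one
        U[head] += prod
    return U

def naive_treesum(tol=1e-12, timeout=100_000):
    V = dict.fromkeys(nonterminals, 0.0)
    for _ in range(timeout):
        U = bottom_up_step(V)
        if all(abs(U[X] - V[X]) <= tol for X in nonterminals):
            break
        V = U
    return V
```

Starting from the zero chart, iteration converges to the least solution of x = 0.25 x² + 0.75, which is 1.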

null_weight()

Compute the map from nonterminal to total weight of generating the empty string.

Returns:

A dict mapping nonterminals to their null weights

Source code in genlm/grammar/cfg.py
def null_weight(self):
    """
    Compute the map from nonterminal to total weight of generating the empty string.

    Returns:
        A dict mapping nonterminals to their null weights
    """
    ecfg = self.spawn(V=set())
    for p in self:
        if not any(self.is_terminal(y) for y in p.body):
            ecfg.add(p.w, p.head, *p.body)
    return ecfg.agenda()
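
A rough standalone illustration of the idea (real weights, a fixed number of fixpoint iterations in place of the library's `agenda`; `null_weights` is a hypothetical name): rules whose bodies mention a terminal can never derive the empty string, so only terminal-free rules participate.

```python
# Toy grammar: S -> A A (1.0); A -> ε (0.3) | a (0.7)
rules = [(1.0, "S", ("A", "A")), (0.3, "A", ()), (0.7, "A", ("a",))]
nonterminals = {"S", "A"}

def null_weights(iters=50):
    V = dict.fromkeys(nonterminals, 0.0)
    for _ in range(iters):
        U = dict.fromkeys(nonterminals, 0.0)
        for w, head, body in rules:
            if any(y not in nonterminals for y in body):
                continue  # body mentions a terminal: contributes nothing to ε
            prod = w
            for y in body:
                prod *= V[y]
            U[head] += prod
        V = U
    return V
```

Here A derives ε with weight 0.3, so S (which must derive ε from both copies of A) gets 0.3 × 0.3 = 0.09.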

null_weight_start()

Compute the null weight of the start symbol.

Returns:

The total weight of generating the empty string from the start symbol

Source code in genlm/grammar/cfg.py
def null_weight_start(self):
    """
    Compute the null weight of the start symbol.

    Returns:
        The total weight of generating the empty string from the start symbol
    """
    return self.null_weight()[self.S]

nullaryremove(binarize=True, trim=True, **kwargs)

Return an equivalent grammar with no nullary rules except for one at the start symbol.

Parameters:

binarize: If True, binarize the grammar first (default: True)
trim: If True, trim the resulting grammar (default: True)
**kwargs: Additional arguments passed to _push_null_weights

Returns:

A new CFG without nullary rules (except at start)

Source code in genlm/grammar/cfg.py
def nullaryremove(self, binarize=True, trim=True, **kwargs):
    """
    Return an equivalent grammar with no nullary rules except for one at the start symbol.

    Args:
        binarize: If True, binarize the grammar first
        trim: If True, trim the resulting grammar
        **kwargs: Additional arguments passed to _push_null_weights

    Returns:
        A new CFG without nullary rules (except at start)
    """
    # A really wide rule can take a very long time because of the power-set
    # expansion over its body, so it is important to binarize first.
    if binarize:
        self = self.binarize()  # pragma: no cover
    self = self.separate_start()
    tmp = self._push_null_weights(self.null_weight(), **kwargs)
    return tmp.trim() if trim else tmp

prefix_weight(xs)

Total weight of all derivations that have xs as a prefix.

Source code in genlm/grammar/cfg.py
def prefix_weight(self, xs):
    "Total weight of all derivations that have `xs` as a prefix."
    return self.prefix_grammar(xs)

rename(f)

Return a new grammar that is the result of applying f to each nonterminal.

Parameters:

f: Function to rename nonterminals (required)

Returns:

A new CFG with renamed nonterminals

Source code in genlm/grammar/cfg.py
def rename(self, f):
    """
    Return a new grammar that is the result of applying f to each nonterminal.

    Args:
        f: Function to rename nonterminals

    Returns:
        A new CFG with renamed nonterminals
    """
    new = self.spawn(S=f(self.S))
    for r in self:
        new.add(
            r.w, f(r.head), *((y if self.is_terminal(y) else f(y) for y in r.body))
        )
    return new

renumber()

Rename nonterminals to integers.

Returns:

A new CFG with integer nonterminals

Source code in genlm/grammar/cfg.py
def renumber(self):
    """
    Rename nonterminals to integers.

    Returns:
        A new CFG with integer nonterminals
    """
    i = Integerizer()
    max_v = max((x for x in self.V if isinstance(x, int)), default=0)
    return self.rename(lambda x: i(x) + max_v + 1)
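
Since the terminal vocabulary may itself contain integers, the new nonterminal ids must start above the largest integer terminal to keep the two alphabets disjoint. A standalone sketch (`integerize` imitates the `Integerizer` used above; the names here are illustrative):

```python
V = {0, 1, "x"}                       # terminals, some of them ints
nonterminals = ["S", "NP", "VP"]

ids = {}
def integerize(x):
    """Assign each distinct symbol the next unused small integer."""
    if x not in ids:
        ids[x] = len(ids)
    return ids[x]

# Offset every id past the largest integer terminal (here 1).
max_v = max((t for t in V if isinstance(t, int)), default=0)
renamed = {X: integerize(X) + max_v + 1 for X in nonterminals}
```

With integer terminals 0 and 1 present, the nonterminals are mapped to 2, 3, 4 rather than 0, 1, 2.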

separate_start()

Ensure that the start symbol does not appear on the RHS of any rule.

Returns:

A new CFG with start symbol only on LHS

Source code in genlm/grammar/cfg.py
def separate_start(self):
    """
    Ensure that the start symbol does not appear on the RHS of any rule.

    Returns:
        A new CFG with start symbol only on LHS
    """
    # create a new start symbol if the current one appears on the rhs of any existing rule
    if self.S in {y for r in self for y in r.body}:
        S = _gen_nt(self.S)
        new = self.spawn(S=S)
        # route the old start symbol through the new one
        new.add(self.R.one, S, self.S)
        for r in self:
            new.add(r.w, r.head, *r.body)
        return new
    else:
        return self

separate_terminals()

Ensure that each terminal is produced by a preterminal rule.

Returns:

A new CFG with terminals only in preterminal rules

Source code in genlm/grammar/cfg.py
def separate_terminals(self):
    """
    Ensure that each terminal is produced by a preterminal rule.

    Returns:
        A new CFG with terminals only in preterminal rules
    """
    one = self.R.one
    new = self.spawn()

    _preterminal = {}

    def preterminal(x):
        y = _preterminal.get(x)
        if y is None:
            y = new.add(one, _gen_nt(), x)
            _preterminal[x] = y
        return y

    for r in self:
        if len(r.body) == 1 and self.is_terminal(r.body[0]):
            new.add(r.w, r.head, *r.body)
        else:
            new.add(
                r.w,
                r.head,
                *(
                    (preterminal(y).head if self.is_terminal(y) else y)
                    for y in r.body
                ),
            )

    return new
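
The transformation can be sketched on plain `(weight, head, body)` tuples: terminals appearing in long bodies are replaced by fresh preterminals with weight-one unary rules, while existing unary terminal rules are left alone. The fresh names `T0, T1, ...` below are hypothetical (the library generates its own via `_gen_nt`):

```python
from itertools import count

fresh = (f"T{i}" for i in count())
preterminal = {}

def get_preterminal(t, new_rules):
    """Return (creating if needed) a preterminal nonterminal for terminal t."""
    if t not in preterminal:
        X = next(fresh)
        preterminal[t] = X
        new_rules.append((1.0, X, (t,)))
    return preterminal[t]

def separate_terminals(rules, terminals):
    new_rules = []
    for w, head, body in rules:
        if len(body) == 1 and body[0] in terminals:
            new_rules.append((w, head, body))   # already a preterminal rule
        else:
            new_body = tuple(get_preterminal(y, new_rules) if y in terminals else y
                             for y in body)
            new_rules.append((w, head, new_body))
    return new_rules

out = separate_terminals([(1.0, "S", ("a", "S")), (0.5, "S", ("a",))], {"a"})
```

After the transformation, the terminal `a` appears only in unary rule bodies.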

spawn(*, R=None, S=None, V=None)

Create an empty grammar with the same R, S, and V.

Parameters:

R: Optional new semiring (default: None)
S: Optional new start symbol (default: None)
V: Optional new vocabulary (default: None)

Returns:

A new empty CFG with specified parameters

Source code in genlm/grammar/cfg.py
def spawn(self, *, R=None, S=None, V=None):
    """
    Create an empty grammar with the same R, S, and V.

    Args:
        R: Optional new semiring
        S: Optional new start symbol
        V: Optional new vocabulary

    Returns:
        A new empty CFG with specified parameters
    """
    return self.__class__(
        R=self.R if R is None else R,
        S=self.S if S is None else S,
        V=set(self.V) if V is None else V,
    )

to_bytes()

Convert terminal symbols from strings to bytes representation.

This method creates a new grammar where all terminal string symbols are converted to their UTF-8 byte representation. Non-terminal symbols are preserved as-is.

Returns:

CFG: A new grammar with byte terminal symbols

Raises:

ValueError: If a terminal symbol is not a string

Source code in genlm/grammar/cfg.py
def to_bytes(self):
    """Convert terminal symbols from strings to bytes representation.

    This method creates a new grammar where all terminal string symbols are
    converted to their UTF-8 byte representation. Non-terminal symbols are
    preserved as-is.

    Returns:
        CFG: A new grammar with byte terminal symbols

    Raises:
        ValueError: If a terminal symbol is not a string
    """
    new = self.spawn(S=self.S, R=self.R, V=set())

    for r in self:
        new_body = []
        for x in r.body:
            if self.is_terminal(x):
                if not isinstance(x, str):
                    raise ValueError(f"unsupported terminal type: {type(x)}")
                bs = list(x.encode("utf-8"))
                for b in bs:
                    new.V.add(b)
                new_body.extend(bs)
            else:
                new_body.append(x)
        new.add(r.w, r.head, *new_body)

    return new
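
The core of the body rewrite can be shown in isolation (a sketch on a single rule body; `to_bytes_terminals` is a name invented for illustration). Iterating over a `bytes` object yields one `int` per byte, so a multi-byte character like `é` expands into two integer terminals:

```python
def to_bytes_terminals(body, nonterminals):
    """Expand string terminals into their UTF-8 byte values, keeping nonterminals."""
    new_body = []
    for x in body:
        if x in nonterminals:
            new_body.append(x)
        elif isinstance(x, str):
            new_body.extend(x.encode("utf-8"))  # one int per byte
        else:
            raise ValueError(f"unsupported terminal type: {type(x)}")
    return new_body
```

For example, `["é", "S"]` with nonterminal `S` becomes `[195, 169, "S"]`.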

treesum(**kwargs)

Total weight of the start symbol.

Returns:

The total weight of all derivations from the start symbol

Source code in genlm/grammar/cfg.py
def treesum(self, **kwargs):
    """
    Total weight of the start symbol.

    Returns:
        The total weight of all derivations from the start symbol
    """
    return self.agenda(**kwargs)[self.S]

trim(bottomup_only=False)

Return an equivalent grammar with no dead or useless nonterminals or rules.

Parameters:

bottomup_only: If True, only remove non-generating nonterminals (default: False)

Returns:

A new trimmed CFG

Source code in genlm/grammar/cfg.py
def trim(self, bottomup_only=False):
    """
    Return an equivalent grammar with no dead or useless nonterminals or rules.

    Args:
        bottomup_only: If True, only remove non-generating nonterminals

    Returns:
        A new trimmed CFG
    """
    if self._trim_cache[bottomup_only] is not None:
        return self._trim_cache[bottomup_only]

    C = set(self.V)
    C.update(e.head for e in self.rules if len(e.body) == 0)

    incoming = defaultdict(list)
    outgoing = defaultdict(list)
    for e in self:
        incoming[e.head].append(e)
        for b in e.body:
            outgoing[b].append(e)

    agenda = set(C)
    while agenda:
        x = agenda.pop()
        for e in outgoing[x]:
            if all((b in C) for b in e.body):
                if e.head not in C:
                    C.add(e.head)
                    agenda.add(e.head)

    if bottomup_only:
        val = self._trim(C)
        self._trim_cache[bottomup_only] = val
        val._trim_cache[bottomup_only] = val
        return val

    T = {self.S}
    agenda.update(T)
    while agenda:
        x = agenda.pop()
        for e in incoming[x]:
            # assert e.head in T
            for b in e.body:
                if b not in T and b in C:
                    T.add(b)
                    agenda.add(b)

    val = self._trim(T)
    self._trim_cache[bottomup_only] = val
    val._trim_cache[bottomup_only] = val
    return val
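
The two passes can be sketched independently of the library's data structures (sets of symbols instead of cached grammar objects; `trim_symbols` is a hypothetical name): a bottom-up pass finds symbols that generate some terminal string, then a top-down pass keeps only those also reachable from the start.

```python
rules = [
    (1.0, "S", ("A",)),      # S is reachable and A generates: both survive
    (1.0, "A", ("a",)),
    (1.0, "B", ("b",)),      # B generates but is unreachable from S
    (1.0, "S", ("C",)),      # C is reachable but generates nothing
    (1.0, "C", ("C",)),
]
terminals = {"a", "b"}
nonterminals = {"S", "A", "B", "C"}

def trim_symbols(start):
    # Bottom-up: which symbols can derive some terminal string?
    generating = set(terminals)
    changed = True
    while changed:
        changed = False
        for _, head, body in rules:
            if head not in generating and all(y in generating for y in body):
                generating.add(head)
                changed = True
    # Top-down: which generating symbols are reachable from the start?
    useful = {start} if start in generating else set()
    changed = True
    while changed:
        changed = False
        for _, head, body in rules:
            if head in useful:
                for y in body:
                    if y in generating and y not in useful:
                        useful.add(y)
                        changed = True
    return useful
```

Here `B` is dropped as unreachable and `C` as non-generating, leaving only `S`, `A`, and the terminal `a`.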

truncate_length(max_length)

Transform this grammar so that it only generates strings with length ≤ max_length.

Source code in genlm/grammar/cfg.py
def truncate_length(self, max_length):
    "Transform this grammar so that it only generates strings with length ≤ `max_length`."
    from genlm.grammar import WFSA

    m = WFSA(self.R)
    m.add_I(0, self.R.one)
    m.add_F(0, self.R.one)
    for t in range(max_length):
        for x in self.V:
            m.add_arc(t, x, t + 1, self.R.one)
        m.add_F(t + 1, self.R.one)
    return self @ m
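
The WFSA built above accepts every string over `V` of length at most `max_length`, with weight one; intersecting the grammar with it (`self @ m`) then filters the language by length. An unweighted standalone sketch of that automaton (dict-based, hypothetical helper names):

```python
def length_automaton(V, max_length):
    """Acceptor for all strings over V of length <= max_length.
    State t counts how many symbols have been read; every state is final."""
    arcs = {(t, x): t + 1 for t in range(max_length) for x in V}
    finals = set(range(max_length + 1))
    return arcs, finals

def accepts(arcs, finals, xs):
    state = 0
    for x in xs:
        if (state, x) not in arcs:
            return False      # ran past max_length
        state = arcs[(state, x)]
    return state in finals

arcs, finals = length_automaton({"a", "b"}, 2)
```

With `max_length=2`, the automaton accepts `""`, `"a"`, and `"ab"` but rejects `"aba"`.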

unarycycleremove(trim=True)

Return an equivalent grammar with no unary cycles.

Parameters:

trim: If True, trim the resulting grammar (default: True)

Returns:

A new CFG without unary cycles

Source code in genlm/grammar/cfg.py
def unarycycleremove(self, trim=True):
    """
    Return an equivalent grammar with no unary cycles.

    Args:
        trim: If True, trim the resulting grammar

    Returns:
        A new CFG without unary cycles
    """

    def bot(x):
        return x if x in acyclic else (x, "bot")

    G = self._unary_graph()

    new = self.spawn(S=self.S)

    bucket = G.buckets

    acyclic = set()
    for nodes, _ in G.Blocks:
        if len(nodes) == 1:
            [X] = nodes
            if G[X, X] == self.R.zero:
                acyclic.add(X)

    # run Lehmann's on each cyclical SCC
    for nodes, W in G.Blocks:
        if len(nodes) == 1:
            [X] = nodes
            if X in acyclic:
                continue

        for X1, X2 in W:
            new.add(W[X1, X2], X1, bot(X2))

    for r in self:
        if len(r.body) == 1 and bucket.get(r.body[0]) == bucket[r.head]:
            continue
        new.add(r.w, bot(r.head), *r.body)

    # TODO: figure out how to ensure that the new grammar is trimmed by
    # construction (assuming the input grammar was trim).
    if trim:
        new = new.trim()

    return new

unaryremove()

Return an equivalent grammar with no unary rules.

Returns:

A new CFG without unary rules

Source code in genlm/grammar/cfg.py
def unaryremove(self):
    """
    Return an equivalent grammar with no unary rules.

    Returns:
        A new CFG without unary rules
    """
    W = self._unary_graph().closure_scc_based()
    # W = self._unary_graph().closure_reference()

    new = self.spawn()
    for r in self:
        if len(r.body) == 1 and self.is_nonterminal(r.body[0]):
            continue
        for Y in self.N:
            new.add(W[Y, r.head] * r.w, Y, *r.body)

    return new

unfold(i, k)

Apply the unfolding transformation to rule i and subgoal k.

Parameters:

i: Index of rule to unfold (required)
k: Index of subgoal in rule body (required)

Returns:

A new CFG with the rule unfolded

Source code in genlm/grammar/cfg.py
def unfold(self, i, k):
    """
    Apply the unfolding transformation to rule i and subgoal k.

    Args:
        i: Index of rule to unfold
        k: Index of subgoal in rule body

    Returns:
        A new CFG with the rule unfolded
    """
    assert isinstance(i, int) and isinstance(k, int)
    s = self.rules[i]
    assert self.is_nonterminal(s.body[k])

    new = self.spawn()
    for j, r in enumerate(self):
        if j != i:
            new.add(r.w, r.head, *r.body)

    for r in self.rhs[s.body[k]]:
        new.add(s.w * r.w, s.head, *s.body[:k], *r.body, *s.body[k + 1 :])

    return new
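
Unfolding replaces rule i by one copy per rule of the nonterminal at body position k, splicing that rule's body in place of the subgoal and multiplying the weights. A standalone sketch on `(weight, head, body)` tuples (not the library's `Rule`/`rhs` machinery):

```python
rules = [
    (0.5, "S", ("A", "b")),   # rule 0: unfold the A at body position 0
    (0.3, "A", ("a",)),
    (0.7, "A", ("a", "A")),
]

def unfold(rules, i, k):
    s_w, s_head, s_body = rules[i]
    X = s_body[k]
    new = [r for j, r in enumerate(rules) if j != i]   # drop the unfolded rule
    for w, head, body in rules:
        if head == X:
            # Splice this rule's body in place of the subgoal.
            new.append((s_w * w, s_head, s_body[:k] + body + s_body[k + 1:]))
    return new

out = unfold(rules, 0, 0)
```

Here `S -> A b` is replaced by `S -> a b` (weight 0.5 × 0.3) and `S -> a A b` (weight 0.5 × 0.7), while the A-rules are kept.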

Derivation

A derivation tree in a context-free grammar.

Attributes:

r (Rule): The rule used at this node, or None
x: The symbol at this node
ys: Child nodes of the derivation

Source code in genlm/grammar/cfg.py
class Derivation:
    """A derivation tree in a context-free grammar.

    Attributes:
        r (Rule): The rule used at this node, or None
        x: The symbol at this node
        ys: Child nodes of the derivation
    """

    def __init__(self, r, x, *ys):
        """Initialize a Derivation.

        Args:
            r (Rule): The rule used at this node, or None
            x: The symbol at this node
            *ys: Child nodes of the derivation
        """
        assert isinstance(r, Rule) or r is None
        self.r = r
        self.x = x
        self.ys = ys

    def __hash__(self):
        return hash((self.r, self.x, self.ys))

    def __eq__(self, other):
        return (self.r, self.x, self.ys) == (other.r, other.x, other.ys)

    def __repr__(self):
        open_p = colors.dark.white % "("
        close_p = colors.dark.white % ")"
        children = " ".join(str(y) for y in self.ys)
        return f"{open_p}{self.x} {children}{close_p}"

    def weight(self):
        """Compute the weight of this derivation.

        Returns:
            The weight of this derivation, computed by multiplying the rule weight
            with the weights of all child derivations.
        """
        W = self.r.w
        for y in self.ys:
            if isinstance(y, Derivation):
                W *= y.weight()
        return W

    def Yield(self):
        """Return the yield (terminal string) of this derivation.

        Returns:
            tuple: The sequence of terminal symbols at the leaves of this derivation tree.
        """
        if isinstance(self, Derivation):
            return tuple(w for y in self.ys for w in Derivation.Yield(y))
        else:
            return (self,)

    def _repr_html_(self):
        return self.to_nltk()._repr_svg_()

Yield()

Return the yield (terminal string) of this derivation.

Returns:

tuple: The sequence of terminal symbols at the leaves of this derivation tree.

Source code in genlm/grammar/cfg.py
def Yield(self):
    """Return the yield (terminal string) of this derivation.

    Returns:
        tuple: The sequence of terminal symbols at the leaves of this derivation tree.
    """
    if isinstance(self, Derivation):
        return tuple(w for y in self.ys for w in Derivation.Yield(y))
    else:
        return (self,)

__init__(r, x, *ys)

Initialize a Derivation.

Parameters:

r (Rule): The rule used at this node, or None (required)
x: The symbol at this node (required)
*ys: Child nodes of the derivation (default: ())
Source code in genlm/grammar/cfg.py
def __init__(self, r, x, *ys):
    """Initialize a Derivation.

    Args:
        r (Rule): The rule used at this node, or None
        x: The symbol at this node
        *ys: Child nodes of the derivation
    """
    assert isinstance(r, Rule) or r is None
    self.r = r
    self.x = x
    self.ys = ys

weight()

Compute the weight of this derivation.

Returns:

The weight of this derivation, computed by multiplying the rule weight with the weights of all child derivations.

Source code in genlm/grammar/cfg.py
def weight(self):
    """Compute the weight of this derivation.

    Returns:
        The weight of this derivation, computed by multiplying the rule weight
        with the weights of all child derivations.
    """
    W = self.r.w
    for y in self.ys:
        if isinstance(y, Derivation):
            W *= y.weight()
    return W

NotNull

A non-null nonterminal, which is used for the nullary elimination. Denotes a non-terminal that cannot yield an empty string.

Source code in genlm/grammar/cfg.py
class NotNull:
    """A non-null nonterminal, which is used for the nullary elimination.
    Denotes a non-terminal that cannot yield an empty string."""

    __slots__ = ("x",)

    def __init__(self, x):
        self.x = x

    def __repr__(self):
        return f"{self.x}"  # pragma: no cover

    def __hash__(self):
        return hash(("NotNull", self.x))

    def __eq__(self, other):
        return isinstance(other, NotNull) and self.x == other.x

Other

Generates a novel 'other' nonterminal, which may be used in various grammar transformations.

Source code in genlm/grammar/cfg.py
class Other:
    """Generates a novel 'other' nonterminal, which may be used in
    various grammar transformations."""

    __slots__ = ("x",)

    def __init__(self, x):
        self.x = x

    def __repr__(self):
        return f"{self.x}"

    def __hash__(self):
        return hash(("Other", self.x))

    def __eq__(self, other):
        return isinstance(other, Other) and self.x == other.x

Rule

A weighted production rule in a context-free grammar.

Attributes:

w: Weight of the rule
head: Left-hand side nonterminal symbol
body: Right-hand side sequence of symbols

Source code in genlm/grammar/cfg.py
class Rule:
    """A weighted production rule in a context-free grammar.

    Attributes:
        w: Weight of the rule
        head: Left-hand side nonterminal symbol
        body: Right-hand side sequence of symbols
    """

    def __init__(self, w, head, body):
        """Initialize a Rule.

        Args:
            w: Weight of the rule
            head: Left-hand side nonterminal symbol
            body: Right-hand side sequence of symbols
        """
        self.w = w
        self.head = head
        self.body = body
        self._hash = hash((head, body))

    def __eq__(self, other):
        return (
            isinstance(other, Rule)
            and self.w == other.w
            and self._hash == other._hash
            and other.head == self.head
            and other.body == self.body
        )

    def __hash__(self):
        return self._hash

    def __repr__(self):
        return f"{self.w}: {self.head} -> {' '.join(map(str, self.body))}"

__init__(w, head, body)

Initialize a Rule.

Parameters:

w: Weight of the rule (required)
head: Left-hand side nonterminal symbol (required)
body: Right-hand side sequence of symbols (required)
Source code in genlm/grammar/cfg.py
def __init__(self, w, head, body):
    """Initialize a Rule.

    Args:
        w: Weight of the rule
        head: Left-hand side nonterminal symbol
        body: Right-hand side sequence of symbols
    """
    self.w = w
    self.head = head
    self.body = body
    self._hash = hash((head, body))

Slash

A slash nonterminal, which is used for the derivative grammar.

Source code in genlm/grammar/cfg.py
class Slash:
    """A slash nonterminal, which is used for the derivative grammar."""

    __slots__ = ("Y", "Z", "i")

    def __init__(self, Y, Z, i):
        self.Y = Y
        self.Z = Z
        self.i = i

    def __repr__(self):
        return f"{self.Y}/{self.Z}@{self.i}"  # pragma: no cover

    def __hash__(self) -> int:
        return hash((self.Y, self.Z, self.i))

    def __eq__(self, other):
        return (
            isinstance(other, Slash)
            and self.Y == other.Y
            and self.Z == other.Z
            and self.i == other.i
        )

prefix_transducer(R, V)

Construct the prefix transducer over semiring R and alphabet V.

Source code in genlm/grammar/cfg.py
def prefix_transducer(R, V):
    "Construct the prefix transducer over semiring `R` and alphabet `V`."
    P = FST(R)
    P.add_I(0, R.one)
    P.add_I(1, R.one)
    for x in V:
        P.add_arc(0, (x, x), 0, R.one)
        P.add_arc(0, (x, x), 1, R.one)
        P.add_arc(1, (x, EPSILON), 1, R.one)
    P.add_F(1, R.one)
    return P