Skip to content

Commit d78f771

Browse files
author
Sasha
committed
.
1 parent 7ad4fc0 commit d78f771

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

torch_struct/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
SemiMarkovCRF,
66
DependencyCRF,
77
TreeCRF,
8+
SentCFG,
89
)
910
from .cky_crf import CKY_CRF
1011
from .deptree import DepTree
@@ -42,4 +43,5 @@
4243
SemiMarkovCRF,
4344
DependencyCRF,
4445
TreeCRF,
46+
SentCFG,
4547
]

torch_struct/distributions.py

+14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch.distributions.distribution import Distribution
33
from torch.distributions.utils import lazy_property
44
from .linearchain import LinearChain
5+
from .cky import CKY
56
from .semimarkov import SemiMarkov
67
from .deptree import DepTree
78
from .cky_crf import CKY_CRF
@@ -100,3 +101,16 @@ class DependencyCRF(StructDistribution):
100101

101102
class TreeCRF(StructDistribution):
102103
struct = CKY_CRF
104+
105+
106+
class SentCFG(StructDistribution):
107+
struct = CKY
108+
109+
def __init__(self, log_potentials, lengths=None):
110+
batch_shape = log_potentials[0].shape[:1]
111+
event_shape = log_potentials[0].shape[1:]
112+
self.log_potentials = log_potentials
113+
self.lengths = lengths
114+
super(StructDistribution, self).__init__(
115+
batch_shape=batch_shape, event_shape=event_shape
116+
)

0 commit comments

Comments
 (0)