File tree 2 files changed +16
-0
lines changed
2 files changed +16
-0
lines changed Original file line number Diff line number Diff line change 5
5
SemiMarkovCRF ,
6
6
DependencyCRF ,
7
7
TreeCRF ,
8
+ SentCFG ,
8
9
)
9
10
from .cky_crf import CKY_CRF
10
11
from .deptree import DepTree
42
43
SemiMarkovCRF ,
43
44
DependencyCRF ,
44
45
TreeCRF ,
46
+ SentCFG ,
45
47
]
Original file line number Diff line number Diff line change 2
2
from torch .distributions .distribution import Distribution
3
3
from torch .distributions .utils import lazy_property
4
4
from .linearchain import LinearChain
5
+ from .cky import CKY
5
6
from .semimarkov import SemiMarkov
6
7
from .deptree import DepTree
7
8
from .cky_crf import CKY_CRF
@@ -100,3 +101,16 @@ class DependencyCRF(StructDistribution):
100
101
101
102
class TreeCRF (StructDistribution ):
102
103
struct = CKY_CRF
104
+
105
+
106
+ class SentCFG (StructDistribution ):
107
+ struct = CKY
108
+
109
+ def __init__ (self , log_potentials , lengths = None ):
110
+ batch_shape = log_potentials [0 ].shape [:1 ]
111
+ event_shape = log_potentials [0 ].shape [1 :]
112
+ self .log_potentials = log_potentials
113
+ self .lengths = lengths
114
+ super (StructDistribution , self ).__init__ (
115
+ batch_shape = batch_shape , event_shape = event_shape
116
+ )
You can’t perform that action at this time.
0 commit comments