Skip to content

Commit b5e2f5c

Browse files
committed
resolve merge conflicts
1 parent c9fa127 commit b5e2f5c

File tree

3 files changed

+6
-9
lines changed

3 files changed

+6
-9
lines changed

pymc3/tests/test_variational_inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ def test_discrete_not_allowed():
961961

962962
with pm.Model():
963963
mu = pm.Normal("mu", mu=0, sigma=10, size=3)
964-
z = pm.Categorical("z", p=at.ones(3) / 3, size=len(y))
964+
z = pm.Categorical("z", p=aet.ones(3) / 3, size=len(y))
965965
pm.Normal("y_obs", mu=mu[z], sigma=1.0, observed=y)
966966
with pytest.raises(opvi.ParametrizationError):
967967
pm.fit(n=1) # fails

pymc3/variational/opvi.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -833,9 +833,6 @@ def __init__(
833833
options=None,
834834
**kwargs,
835835
):
836-
# XXX: Needs to be refactored for v4
837-
raise NotImplementedError("This class needs to be refactored for v4")
838-
839836
if local and not self.supports_batched:
840837
raise LocalGroupError("%s does not support local groups" % self.__class__)
841838
if local and rowwise:
@@ -959,7 +956,7 @@ def __init_group__(self, group):
959956
self.group = [get_transformed(var) for var in self.group]
960957

961958
# XXX: This needs to be refactored
962-
# self.ordering = ArrayOrdering([])
959+
self.point_map_info = []
963960
self.replacements = dict()
964961
for var in self.group:
965962
if var.type.numpy_dtype.name in discrete_types:
@@ -975,18 +972,18 @@ def __init_group__(self, group):
975972
# self.ordering.size += None # (np.prod(var.dshape[1:])).astype(int)
976973
if self.local:
977974
# XXX: This needs to be refactored
978-
shape = None # (-1,) + var.dshape[1:]
975+
shape = (-1,) + var.dshape[1:]
979976
else:
980977
# XXX: This needs to be refactored
981-
shape = None # var.dshape
978+
shape = var.dshape
982979
else:
983980
# XXX: This needs to be refactored
984981
# self.ordering.size += None # var.dsize
985982
# XXX: This needs to be refactored
986-
shape = None # var.dshape
983+
shape = var.dshape
987984
# end = self.ordering.size
988985
# XXX: This needs to be refactored
989-
vmap = None # VarMap(var.name, slice(begin, end), shape, var.dtype)
986+
vmap = (var.name, shape, var.dtype)
990987
# self.ordering.vmap.append(vmap)
991988
# self.ordering.by_name[vmap.var] = vmap
992989
vr = self.input[..., vmap.slc].reshape(shape).astype(vmap.dtyp)

pymc3/variational/updates.py

100755100644
File mode changed.

0 commit comments

Comments
 (0)