
Commit 055af4d

Fix too strict type check in _sum_grad_over_bcasted_dims
1 parent a377c22 commit 055af4d

File tree

2 files changed: +28 -3 lines


pytensor/tensor/subtensor.py

Lines changed: 1 addition & 2 deletions
@@ -2027,7 +2027,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
     if gx.broadcastable != x.broadcastable:
         x_dim_added = gx.ndim - x.ndim
         x_broad = (True,) * x_dim_added + x.broadcastable
-        assert sum(gx.broadcastable) <= sum(x_broad)
         axis_to_sum = []
         for i in range(gx.ndim):
             if gx.broadcastable[i] is False and x_broad[i] is True:
@@ -2045,7 +2044,7 @@ def _sum_grad_over_bcasted_dims(x, gx):
             for i in range(x_dim_added):
                 assert gx.broadcastable[i]
             gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
-        assert gx.broadcastable == x.broadcastable
+        assert x.type.is_super(gx.type)
     return gx

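For context, a minimal sketch (not part of the commit, with illustrative shapes) of why the exact broadcastable comparison was too strict: after indexing, the gradient gx can carry a more precise static shape than x, so its broadcastable pattern may legitimately differ even though its type is still a valid specialization of x's type, which is what the is_super check accepts (assuming TensorType and Type.is_super behave as on current PyTensor):

# Sketch only; the shapes below are illustrative assumptions, not taken from the commit.
from pytensor.tensor.type import TensorType

x_type = TensorType("float64", shape=(None,))  # x: vector of unknown static length
gx_type = TensorType("float64", shape=(1,))    # gx: length-1 shape inferred after indexing

# broadcastable is derived from static shape: (None,) -> (False,), (1,) -> (True,)
print(gx_type.broadcastable == x_type.broadcastable)  # False -- the old assert would fail here
print(x_type.is_super(gx_type))                       # True  -- the new check passes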
tests/tensor/test_subtensor.py

Lines changed: 27 additions & 1 deletion
@@ -12,16 +12,23 @@
 from pytensor import function
 from pytensor.compile import DeepCopyOp, shared
 from pytensor.compile.io import In
+from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
+from pytensor.gradient import grad
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
 from pytensor.printing import pprint
 from pytensor.scalar.basic import as_scalar, int16
-from pytensor.tensor import as_tensor, get_vector_length, vectorize
+from pytensor.tensor import (
+    as_tensor,
+    get_vector_length,
+    vectorize,
+)
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import exp, isinf
 from pytensor.tensor.math import sum as pt_sum
+from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -1660,6 +1667,25 @@ def just_numeric_args(a, b):
             ),
         )

+    def test_grad_broadcastable_specialization(self):
+        # Make sure gradient does not fail when gx has a more precise static_shape after indexing.
+        # This is a regression test for a bug reported in
+        # https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969/10?u=ricardov94
+
+        x = vector("x")  # Unknown write time shape = (2,)
+        out = x.zeros_like()
+
+        # Update a subtensor of unknown write time shape = (1,)
+        out = out[1:].set(exp(x[1:]))
+        out = specify_shape(out, 2)
+        gx = grad(out.sum(), x)
+
+        mode = Mode(linker="py", optimizer=None)
+        np.testing.assert_allclose(
+            gx.eval({x: [1, 1]}, mode=mode),
+            [0, np.e],
+        )
+

 class TestIncSubtensor1:
     def setup_method(self):
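As a sanity check on the expected result in the new test (not part of the commit): out evaluates to [0, exp(x[1])], so the gradient of out.sum() with respect to x is [0, exp(x[1])], which at x = [1, 1] is [0, e]. A minimal NumPy-only verification of that arithmetic:

# out = [0, exp(x1)]  =>  d(out.sum())/dx = [0, exp(x1)]; at x = [1, 1] this is [0, e]
import numpy as np

x = np.array([1.0, 1.0])
np.testing.assert_allclose([0.0, np.exp(x[1])], [0, np.e])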
