diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index b0f4aaf9fc..161a8143d6 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -2027,7 +2027,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
     if gx.broadcastable != x.broadcastable:
         x_dim_added = gx.ndim - x.ndim
         x_broad = (True,) * x_dim_added + x.broadcastable
-        assert sum(gx.broadcastable) <= sum(x_broad)
         axis_to_sum = []
         for i in range(gx.ndim):
             if gx.broadcastable[i] is False and x_broad[i] is True:
@@ -2045,7 +2044,7 @@ def _sum_grad_over_bcasted_dims(x, gx):
             for i in range(x_dim_added):
                 assert gx.broadcastable[i]
             gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
-        assert gx.broadcastable == x.broadcastable
+        assert x.type.is_super(gx.type)
     return gx
 
 
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index d02880f543..aebd60de56 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -12,7 +12,9 @@
 from pytensor import function
 from pytensor.compile import DeepCopyOp, shared
 from pytensor.compile.io import In
+from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
+from pytensor.gradient import grad
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
 from pytensor.printing import pprint
@@ -22,6 +24,7 @@
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import exp, isinf
 from pytensor.tensor.math import sum as pt_sum
+from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -1660,6 +1663,25 @@ def just_numeric_args(a, b):
             ),
         )
 
+    def test_grad_broadcastable_specialization(self):
+        # Make sure the gradient does not fail when gx has a more precise
+        # static shape than x after indexing. Regression test for a bug reported in
+        # https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969
+
+        x = vector("x")  # Shape unknown at graph-construction time; runtime shape = (2,)
+        out = x.zeros_like()
+
+        # Update a subtensor whose shape is likewise unknown at graph-construction time; runtime shape = (1,)
+        out = out[1:].set(exp(x[1:]))
+        out = specify_shape(out, 2)
+        gx = grad(out.sum(), x)
+
+        mode = Mode(linker="py", optimizer=None)
+        np.testing.assert_allclose(
+            gx.eval({x: [1, 1]}, mode=mode),
+            [0, np.e],
+        )
+
 
 class TestIncSubtensor1:
     def setup_method(self):
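
Context for the change above (not part of the patch): a minimal standalone sketch of the failure mode the new test guards against, using only the PyTensor API already exercised in the diff. The old asserts required the gradient's broadcastable pattern to be no more specialized than x's; with specify_shape, static-shape inference can give the sliced gradient a more precise type (e.g. a known length-1 dimension where x's is unknown), which tripped the assertion. The new check, x.type.is_super(gx.type), only requires that x's type can represent gx's more precise type.

    import numpy as np
    from pytensor.compile.mode import Mode
    from pytensor.gradient import grad
    from pytensor.tensor import exp, vector
    from pytensor.tensor.shape import specify_shape

    x = vector("x")                           # static shape unknown: (None,)
    out = x.zeros_like()[1:].set(exp(x[1:]))
    out = specify_shape(out, 2)               # out now has static shape (2,)
    gx = grad(out.sum(), x)                   # used to raise AssertionError

    print(gx.eval({x: [1.0, 1.0]}, mode=Mode(linker="py", optimizer=None)))
    # expected: [0.         2.71828183]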