diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 6b4ca7570d..29b2043f6d 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -622,9 +622,9 @@ def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: if isinstance(key, str): matching_vars = get_var_by_name([self], key) if not matching_vars: - raise Exception(f"{key} not found in graph") + raise ValueError(f"{key} not found in graph") elif len(matching_vars) > 1: - raise Exception(f"Found multiple variables with name {key}") + raise ValueError(f"Found multiple variables with name {key}") new_input_to_values[matching_vars[0]] = value else: new_input_to_values[key] = value diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b230f035cc..16df2d1b08 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -699,7 +699,10 @@ def local_div_switch_sink(fgraph, node): # will point to the new division op. copy_stack_trace(node.outputs, fdiv) - fct = switch(switch_cond, zero_switch_input, fdiv) + if branch == 0: + fct = switch(switch_cond, zero_switch_input, fdiv) + else: + fct = switch(switch_cond, fdiv, zero_switch_input) # Tell debug_mode than the output is correct, even if nan disappear fct.tag.values_eq_approx = values_eq_approx_remove_nan diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index e4a08cdf81..1160562e62 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2163,7 +2163,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite): # The zero branch upcasts the output, so we can't ignore its dtype zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch") other_branch = scalar("other_branch", dtype="float32") - outer_var = scalar("mul_var", dtype="bool") + outer_var = scalar("outer_var", dtype="bool") out = op(switch(cond, zero_branch, other_branch), outer_var) fgraph = FunctionGraph(outputs=[out], clone=False) @@ -2173,6 +2173,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite): expected_out = switch(cond, zero_branch, op(other_branch, outer_var)) assert equal_computations([new_out], [expected_out]) + @pytest.mark.parametrize( + "op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)] + ) + def test_local_mul_div_switch_sink_branch_order(self, op, rewrite): + cond = scalar("cond", dtype="bool") + zero_branch = constant(np.array(0.0, dtype="float64"), "zero_branch") + other_branch = scalar("other_branch", dtype="float64") + outer_var = scalar("outer_var", dtype="float64") + + left = op(switch(cond, zero_branch, other_branch), outer_var) + right = op(switch(cond, other_branch, zero_branch), outer_var) + fgraph = FunctionGraph(outputs=[left, right], clone=False) + [new_left] = rewrite.transform(fgraph, left.owner) + [new_right] = rewrite.transform(fgraph, right.owner) + + expected_left = switch(cond, zero_branch, op(other_branch, outer_var)) + expected_right = switch(cond, op(other_branch, outer_var), zero_branch) + assert equal_computations( + [new_left, new_right], [expected_left, expected_right] + ) + @pytest.mark.skipif( config.cxx == "",