Skip to content

Commit 7096acd

Browse files
dv8081pavoljuhasNoureldinYosri
authored
Fix Moment.resolve_parameter for constant sympy expressions (#6794)
Problem: `Moment._resolve_parameters_` might keep constant sympy expression if its resolved value is numerically equal. Solution: Check if parameterization changed for the resolved operation and if so replace the original operation with resolved (even if equal). Fixes #6778 --------- Co-authored-by: Pavol Juhas <[email protected]> Co-authored-by: Noureldin <[email protected]>
1 parent b35382f commit 7096acd

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

cirq-core/cirq/circuits/moment.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,11 @@ def _resolve_parameters_(
275275
resolved_ops: List['cirq.Operation'] = []
276276
for op in self:
277277
resolved_op = protocols.resolve_parameters(op, resolver, recursive)
278-
if resolved_op != op:
279-
changed = True
278+
changed = (
279+
changed
280+
or resolved_op != op
281+
or (protocols.is_parameterized(op) and not protocols.is_parameterized(resolved_op))
282+
)
280283
resolved_ops.append(resolved_op)
281284
if not changed:
282285
return self

cirq-core/cirq/circuits/moment_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,14 @@ def test_resolve_parameters():
294294
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
295295
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({'v': 0.1, 'w': 0.2}))
296296
assert resolved_moment == cirq.Moment(cirq.X(a) ** 0.1, cirq.Y(b) ** 0.2)
297+
# sympy constant is resolved to a Python number
298+
moment = cirq.Moment(cirq.Rz(rads=sympy.pi).on(a))
299+
resolved_moment = cirq.resolve_parameters(moment, {'pi': np.pi})
300+
assert resolved_moment == cirq.Moment(cirq.Rz(rads=np.pi).on(a))
301+
resolved_gate = resolved_moment.operations[0].gate
302+
assert not isinstance(resolved_gate.exponent, sympy.Basic)
303+
assert isinstance(resolved_gate.exponent, float)
304+
assert not cirq.is_parameterized(resolved_moment)
297305

298306

299307
def test_resolve_parameters_no_change():

0 commit comments

Comments
 (0)