Skip to content

Commit cc04cfe

Browse files
authored
Recursive subop parameter resolution (quantumlib#5033)
Preserves existing behavior in circuitoperations, where `with_params({a: b, b: a})` just swaps the parameter names and preserves that behavior for subsequent application (we don't change like 613), but we allow optional recursive application for each individual resolution applied (line 614). @95-martin-orion Fixes quantumlib#5016 Closes quantumlib#3619
1 parent 2c6a2f8 commit cc04cfe

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

cirq/circuits/circuit_operation.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -592,17 +592,26 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.Circu
592592
return self.with_measurement_key_mapping(key_map)
593593

594594
def with_params(
595-
self, param_values: 'cirq.ParamResolverOrSimilarType'
595+
self, param_values: 'cirq.ParamResolverOrSimilarType', recursive: bool = False
596596
) -> 'cirq.CircuitOperation':
597597
"""Returns a copy of this operation with an updated ParamResolver.
598598
599+
Any existing parameter mappings will have their values updated given
600+
the provided mapping, and any new parameters will be added to the
601+
ParamResolver.
602+
599603
Note that any resulting parameter mappings with no corresponding
600604
parameter in the base circuit will be omitted.
601605
602606
Args:
603607
param_values: A map or ParamResolver able to convert old param
604608
values to new param values. This map will be composed with any
605609
existing ParamResolver via single-step resolution.
610+
recursive: If True, resolves parameter values recursively over the
611+
resolver; otherwise performs a single resolution step. This
612+
behavior applies only to the passed-in mapping, for the current
613+
application. Existing parameters are never resolved recursively
614+
because a->b and b->a needs to be a valid mapping.
606615
607616
Returns:
608617
A copy of this operation with its ParamResolver updated as specified
@@ -611,18 +620,12 @@ def with_params(
611620
new_params = {}
612621
for k in protocols.parameter_symbols(self.circuit):
613622
v = self.param_resolver.value_of(k, recursive=False)
614-
v = protocols.resolve_parameters(v, param_values, recursive=False)
623+
v = protocols.resolve_parameters(v, param_values, recursive=recursive)
615624
if v != k:
616625
new_params[k] = v
617626
return self.replace(param_resolver=new_params)
618627

619-
# TODO: handle recursive parameter resolution gracefully
620628
def _resolve_parameters_(
621629
self, resolver: 'cirq.ParamResolver', recursive: bool
622630
) -> 'cirq.CircuitOperation':
623-
if recursive:
624-
raise ValueError(
625-
'Recursive resolution of CircuitOperation parameters is prohibited. '
626-
'Use "recursive=False" to prevent this error.'
627-
)
628-
return self.with_params(resolver.param_dict)
631+
return self.with_params(resolver.param_dict, recursive)

cirq/circuits/circuit_operation_test.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import cirq
2121
from cirq.circuits.circuit_operation import _full_join_string_lists
2222

23-
2423
ALL_SIMULATORS = (
2524
cirq.Simulator(),
2625
cirq.DensityMatrixSimulator(),
@@ -248,9 +247,40 @@ def test_with_params():
248247
== op_with_params
249248
)
250249

251-
# Recursive parameter resolution is rejected.
252-
with pytest.raises(ValueError, match='Use "recursive=False"'):
253-
_ = cirq.resolve_parameters(op_base, cirq.ParamResolver(param_dict))
250+
251+
def test_recursive_params():
252+
q = cirq.LineQubit(0)
253+
a, a2, b, b2 = sympy.symbols('a a2 b b2')
254+
circuitop = cirq.CircuitOperation(
255+
cirq.FrozenCircuit(
256+
cirq.X(q) ** a,
257+
cirq.Z(q) ** b,
258+
),
259+
# Not recursive, a and b are swapped.
260+
param_resolver=cirq.ParamResolver({a: b, b: a}),
261+
)
262+
# Recursive, so a->a2->0 and b->b2->1.
263+
outer_params = {a: a2, a2: 0, b: b2, b2: 1}
264+
resolved = cirq.resolve_parameters(circuitop, outer_params)
265+
# Combined, a->b->b2->1, and b->a->a2->0.
266+
assert resolved.param_resolver.param_dict == {a: 1, b: 0}
267+
268+
# Non-recursive, so a->a2 and b->b2.
269+
resolved = cirq.resolve_parameters(circuitop, outer_params, recursive=False)
270+
# Combined, a->b->b2, and b->a->a2.
271+
assert resolved.param_resolver.param_dict == {a: b2, b: a2}
272+
273+
with pytest.raises(RecursionError):
274+
cirq.resolve_parameters(circuitop, {a: a2, a2: a})
275+
276+
# Non-recursive, so a->b and b->a.
277+
resolved = cirq.resolve_parameters(circuitop, {a: b, b: a}, recursive=False)
278+
# Combined, a->b->a, and b->a->b.
279+
assert resolved.param_resolver.param_dict == {}
280+
281+
# First example should behave like an X when simulated
282+
result = cirq.Simulator().simulate(cirq.Circuit(circuitop), param_resolver=outer_params)
283+
assert np.allclose(result.state_vector(), [0, 1])
254284

255285

256286
@pytest.mark.parametrize('add_measurements', [True, False])

0 commit comments

Comments
 (0)