|
20 | 20 | import cirq
|
21 | 21 | from cirq.circuits.circuit_operation import _full_join_string_lists
|
22 | 22 |
|
23 |
| - |
24 | 23 | ALL_SIMULATORS = (
|
25 | 24 | cirq.Simulator(),
|
26 | 25 | cirq.DensityMatrixSimulator(),
|
@@ -248,9 +247,40 @@ def test_with_params():
|
248 | 247 | == op_with_params
|
249 | 248 | )
|
250 | 249 |
|
251 |
| - # Recursive parameter resolution is rejected. |
252 |
| - with pytest.raises(ValueError, match='Use "recursive=False"'): |
253 |
| - _ = cirq.resolve_parameters(op_base, cirq.ParamResolver(param_dict)) |
| 250 | + |
| 251 | +def test_recursive_params(): |
| 252 | + q = cirq.LineQubit(0) |
| 253 | + a, a2, b, b2 = sympy.symbols('a a2 b b2') |
| 254 | + circuitop = cirq.CircuitOperation( |
| 255 | + cirq.FrozenCircuit( |
| 256 | + cirq.X(q) ** a, |
| 257 | + cirq.Z(q) ** b, |
| 258 | + ), |
| 259 | + # Not recursive, a and b are swapped. |
| 260 | + param_resolver=cirq.ParamResolver({a: b, b: a}), |
| 261 | + ) |
| 262 | + # Recursive, so a->a2->0 and b->b2->1. |
| 263 | + outer_params = {a: a2, a2: 0, b: b2, b2: 1} |
| 264 | + resolved = cirq.resolve_parameters(circuitop, outer_params) |
| 265 | + # Combined, a->b->b2->1, and b->a->a2->0. |
| 266 | + assert resolved.param_resolver.param_dict == {a: 1, b: 0} |
| 267 | + |
| 268 | + # Non-recursive, so a->a2 and b->b2. |
| 269 | + resolved = cirq.resolve_parameters(circuitop, outer_params, recursive=False) |
| 270 | + # Combined, a->b->b2, and b->a->a2. |
| 271 | + assert resolved.param_resolver.param_dict == {a: b2, b: a2} |
| 272 | + |
| 273 | + with pytest.raises(RecursionError): |
| 274 | + cirq.resolve_parameters(circuitop, {a: a2, a2: a}) |
| 275 | + |
| 276 | + # Non-recursive, so a->b and b->a. |
| 277 | + resolved = cirq.resolve_parameters(circuitop, {a: b, b: a}, recursive=False) |
| 278 | + # Combined, a->b->a, and b->a->b. |
| 279 | + assert resolved.param_resolver.param_dict == {} |
| 280 | + |
| 281 | + # First example should behave like an X when simulated |
| 282 | + result = cirq.Simulator().simulate(cirq.Circuit(circuitop), param_resolver=outer_params) |
| 283 | + assert np.allclose(result.state_vector(), [0, 1]) |
254 | 284 |
|
255 | 285 |
|
256 | 286 | @pytest.mark.parametrize('add_measurements', [True, False])
|
|
0 commit comments