Skip to content

Commit ec84a05

Browse files
authored
Preserve subcircuits passed to [Frozen]Circuit.from_moments (#6320)
Review: @tanujkhattar
1 parent 96b3842 commit ec84a05

File tree

3 files changed

+49
-9
lines changed

3 files changed

+49
-9
lines changed

cirq-core/cirq/circuits/circuit.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,26 @@ def from_moments(cls: Type[CIRCUIT_TYPE], *moments: 'cirq.OP_TREE') -> CIRCUIT_T
149149
"""Create a circuit from moment op trees.
150150
151151
Args:
152-
*moments: Op tree for each moment.
152+
*moments: Op tree for each moment. If an op tree is a moment, it
153+
will be included directly in the new circuit. If an op tree is
154+
a circuit, it will be frozen, wrapped in a CircuitOperation, and
155+
included in its own moment in the new circuit. Otherwise, the
156+
op tree will be passed to `cirq.Moment` to create a new moment
157+
which is then included in the new circuit. Note that in the
158+
latter case we have the normal restriction that operations in a
159+
moment must be applied to disjoint sets of qubits.
153160
"""
154-
return cls._from_moments(
155-
moment if isinstance(moment, Moment) else Moment(moment) for moment in moments
156-
)
161+
return cls._from_moments(cls._make_moments(moments))
162+
163+
@staticmethod
164+
def _make_moments(moments: Iterable['cirq.OP_TREE']) -> Iterator['cirq.Moment']:
165+
for m in moments:
166+
if isinstance(m, Moment):
167+
yield m
168+
elif isinstance(m, AbstractCircuit):
169+
yield Moment(m.freeze().to_op())
170+
else:
171+
yield Moment(m)
157172

158173
@classmethod
159174
@abc.abstractmethod

cirq-core/cirq/circuits/circuit_test.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import sympy
2424

2525
import cirq
26-
import cirq.testing
2726
from cirq import circuits
2827
from cirq import ops
2928
from cirq.testing.devices import ValidatingTestDevice
@@ -72,19 +71,32 @@ def validate_moment(self, moment):
7271

7372
def test_from_moments():
7473
a, b, c, d = cirq.LineQubit.range(4)
75-
assert cirq.Circuit.from_moments(
74+
moment = cirq.Moment(cirq.Z(a), cirq.Z(b))
75+
subcircuit = cirq.FrozenCircuit.from_moments(cirq.X(c), cirq.Y(d))
76+
circuit = cirq.Circuit.from_moments(
77+
moment,
78+
subcircuit,
7679
[cirq.X(a), cirq.Y(b)],
7780
[cirq.X(c)],
7881
[],
7982
cirq.Z(d),
8083
[cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')],
81-
) == cirq.Circuit(
84+
)
85+
assert circuit == cirq.Circuit(
86+
cirq.Moment(cirq.Z(a), cirq.Z(b)),
87+
cirq.Moment(
88+
cirq.CircuitOperation(
89+
cirq.FrozenCircuit(cirq.Moment(cirq.X(c)), cirq.Moment(cirq.Y(d)))
90+
)
91+
),
8292
cirq.Moment(cirq.X(a), cirq.Y(b)),
8393
cirq.Moment(cirq.X(c)),
8494
cirq.Moment(),
8595
cirq.Moment(cirq.Z(d)),
8696
cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')),
8797
)
98+
assert circuit[0] is moment
99+
assert circuit[1].operations[0].circuit is subcircuit
88100

89101

90102
def test_alignment():

cirq-core/cirq/circuits/frozen_circuit_test.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,32 @@
2424

2525
def test_from_moments():
2626
a, b, c, d = cirq.LineQubit.range(4)
27-
assert cirq.FrozenCircuit.from_moments(
27+
moment = cirq.Moment(cirq.Z(a), cirq.Z(b))
28+
subcircuit = cirq.FrozenCircuit.from_moments(cirq.X(c), cirq.Y(d))
29+
circuit = cirq.FrozenCircuit.from_moments(
30+
moment,
31+
subcircuit,
2832
[cirq.X(a), cirq.Y(b)],
2933
[cirq.X(c)],
3034
[],
3135
cirq.Z(d),
3236
[cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')],
33-
) == cirq.FrozenCircuit(
37+
)
38+
assert circuit == cirq.FrozenCircuit(
39+
cirq.Moment(cirq.Z(a), cirq.Z(b)),
40+
cirq.Moment(
41+
cirq.CircuitOperation(
42+
cirq.FrozenCircuit(cirq.Moment(cirq.X(c)), cirq.Moment(cirq.Y(d)))
43+
)
44+
),
3445
cirq.Moment(cirq.X(a), cirq.Y(b)),
3546
cirq.Moment(cirq.X(c)),
3647
cirq.Moment(),
3748
cirq.Moment(cirq.Z(d)),
3849
cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')),
3950
)
51+
assert circuit[0] is moment
52+
assert circuit[1].operations[0].circuit is subcircuit
4053

4154

4255
def test_freeze_and_unfreeze():

0 commit comments

Comments
 (0)