Skip to content

Commit 1cd3aca

Browse files
authored
Add Moment.from_ops to more efficiently construct moments (quantumlib#6078)
Review: @tanujkhattar
1 parent d3616b7 commit 1cd3aca

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

cirq/circuits/moment.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
AbstractSet,
2020
Any,
2121
Callable,
22+
cast,
2223
Dict,
2324
FrozenSet,
2425
Iterable,
@@ -80,17 +81,26 @@ class Moment:
8081
are no such operations, returns an empty Moment.
8182
"""
8283

83-
def __init__(self, *contents: 'cirq.OP_TREE') -> None:
84+
def __init__(self, *contents: 'cirq.OP_TREE', _flatten_contents: bool = True) -> None:
8485
"""Constructs a moment with the given operations.
8586
8687
Args:
8788
contents: The operations applied within the moment.
8889
Will be flattened and frozen into a tuple before storing.
90+
_flatten_contents: If True, use flatten_to_ops to convert
91+
the OP_TREE of contents into a tuple of Operation. If False,
92+
we skip flattening and assume that contents already consists
93+
of individual operations. This is used internally by helper
94+
methods to avoid unnecessary validation.
8995
9096
Raises:
9197
ValueError: A qubit appears more than once.
9298
"""
93-
self._operations = tuple(op_tree.flatten_to_ops(contents))
99+
self._operations = (
100+
tuple(op_tree.flatten_to_ops(contents))
101+
if _flatten_contents
102+
else cast(Tuple['cirq.Operation'], contents)
103+
)
94104
self._sorted_operations: Optional[Tuple['cirq.Operation', ...]] = None
95105

96106
# An internal dictionary to support efficient operation access by qubit.
@@ -106,6 +116,20 @@ def __init__(self, *contents: 'cirq.OP_TREE') -> None:
106116
self._measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
107117
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None
108118

119+
@classmethod
120+
def from_ops(cls, *ops: 'cirq.Operation') -> 'cirq.Moment':
121+
"""Construct a Moment from the given operations.
122+
123+
This avoids calling `flatten_to_ops` in the moment constructor, which
124+
results in better performance in cases where the contents of the moment
125+
are already in the form of a sequence of operations rather than an
126+
arbitrary OP_TREE.
127+
128+
Args:
129+
*ops: Operations to include in the Moment.
130+
"""
131+
return cls(*ops, _flatten_contents=False)
132+
109133
@property
110134
def operations(self) -> Tuple['cirq.Operation', ...]:
111135
return self._operations
@@ -164,7 +188,7 @@ def with_operation(self, operation: 'cirq.Operation') -> 'cirq.Moment':
164188
raise ValueError(f'Overlapping operations: {operation}')
165189

166190
# Use private variables to facilitate a quick copy.
167-
m = Moment()
191+
m = Moment(_flatten_contents=False)
168192
m._operations = self._operations + (operation,)
169193
m._sorted_operations = None
170194
m._qubits = self._qubits.union(operation.qubits)
@@ -194,7 +218,7 @@ def with_operations(self, *contents: 'cirq.OP_TREE') -> 'cirq.Moment':
194218
if not flattened_contents:
195219
return self
196220

197-
m = Moment()
221+
m = Moment(_flatten_contents=False)
198222
# Use private variables to facilitate a quick copy.
199223
m._qubit_to_op = self._qubit_to_op.copy()
200224
qubits = set(self._qubits)
@@ -483,7 +507,7 @@ def _json_dict_(self) -> Dict[str, Any]:
483507

484508
@classmethod
485509
def _from_json_dict_(cls, operations, **kwargs):
486-
return Moment(operations)
510+
return cls.from_ops(*operations)
487511

488512
def __add__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
489513

cirq/circuits/moment_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,13 @@ def test_operation_at():
170170
assert cirq.Moment([cirq.CZ(a, b), cirq.X(c)]).operation_at(a) == cirq.CZ(a, b)
171171

172172

173+
def test_from_ops():
174+
a = cirq.NamedQubit('a')
175+
b = cirq.NamedQubit('b')
176+
177+
assert cirq.Moment.from_ops(cirq.X(a), cirq.Y(b)) == cirq.Moment(cirq.X(a), cirq.Y(b))
178+
179+
173180
def test_with_operation():
174181
a = cirq.NamedQubit('a')
175182
b = cirq.NamedQubit('b')

0 commit comments

Comments
 (0)