Skip to content

Commit 5f1f238

Browse files
authored
Fix CCO related nits in cirq.Operation and cirq.TaggedOperation (quantumlib#5390)
* Improve CCO support in cirq.Operation and cirq.TaggedOperation * Fix failing tests * Update docstrings * Clarify without_classical_controls docstring * Reword docstrings
1 parent c19e71d commit 5f1f238

4 files changed

+38
-10
lines changed

cirq/ops/classically_controlled_operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _resolve_parameters_(
134134
self, resolver: 'cirq.ParamResolver', recursive: bool
135135
) -> 'ClassicallyControlledOperation':
136136
new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive)
137-
return new_sub_op.with_classical_controls(*self._conditions)
137+
return ClassicallyControlledOperation(new_sub_op, self._conditions)
138138

139139
def _circuit_diagram_info_(
140140
self, args: 'cirq.CircuitDiagramInfoArgs'

cirq/ops/classically_controlled_operation_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def test_condition_removal():
385385
op = op.without_classical_controls()
386386
assert not cirq.control_keys(op)
387387
assert not op.classical_controls
388-
assert set(map(str, op.tags)) == {'t1'}
388+
assert not op.tags
389389

390390

391391
def test_qubit_mapping():

cirq/ops/raw_types.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -622,27 +622,33 @@ def classical_controls(self) -> FrozenSet['cirq.Condition']:
622622

623623
def with_classical_controls(
624624
self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr]
625-
) -> 'cirq.ClassicallyControlledOperation':
625+
) -> 'cirq.Operation':
626626
"""Returns a classically controlled version of this operation.
627627
628628
An operation that is classically controlled is executed iff all
629629
conditions evaluate to True. Currently the only condition type is a
630630
measurement key. A measurement key evaluates to True iff any qubit in
631631
the corresponding measurement operation evaluated to a non-zero value.
632632
633-
The classical control will hide any tags on the existing operation,
634-
since tags are considered a local attribute.
633+
If no conditions are specified, returns self.
634+
635+
The classical control will remove any tags on the existing operation,
636+
since tags are fragile, and we always opt to get rid of the tags when
637+
the underlying operation is changed.
635638
636639
Args:
637640
*conditions: A list of measurement keys, strings that can be parsed
638641
into measurement keys, or sympy expressions where the free
639642
symbols are measurement key strings.
640643
641644
Returns:
642-
A `ClassicallyControlledOperation` wrapping the operation.
645+
A `ClassicallyControlledOperation` wrapping the operation. If no conditions
646+
are specified, returns self.
643647
"""
644648
from cirq.ops.classically_controlled_operation import ClassicallyControlledOperation
645649

650+
if not conditions:
651+
return self
646652
return ClassicallyControlledOperation(self, conditions)
647653

648654
def without_classical_controls(self) -> 'cirq.Operation':
@@ -655,10 +661,10 @@ def without_classical_controls(self) -> 'cirq.Operation':
655661
If there are no classical controls on the operation, it will return
656662
`self`.
657663
658-
Since tags are considered local, this will also remove any tags from
659-
the operation (unless there are no classical controls on it). If a
660-
`TaggedOperation` is under all the classical control layers, that
661-
`TaggedOperation` will be returned from this function.
664+
Since tags are fragile, this will also remove any tags from the operation,
665+
when called on `TaggedOperation` (unless there are no classical controls on it).
666+
If a `TaggedOperation` is under all the classical control layers,
667+
that `TaggedOperation` will be returned from this function.
662668
663669
Returns:
664670
The operation with all classical controls removed.
@@ -712,6 +718,8 @@ def controlled_by(
712718
*control_qubits: 'cirq.Qid',
713719
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
714720
) -> 'cirq.Operation':
721+
if len(control_qubits) == 0:
722+
return self
715723
return self.sub_operation.controlled_by(*control_qubits, control_values=control_values)
716724

717725
@property
@@ -864,6 +872,13 @@ def without_classical_controls(self) -> 'cirq.Operation':
864872
new_sub_operation = self.sub_operation.without_classical_controls()
865873
return self if new_sub_operation is self.sub_operation else new_sub_operation
866874

875+
def with_classical_controls(
876+
self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr]
877+
) -> 'cirq.Operation':
878+
if not conditions:
879+
return self
880+
return self.sub_operation.with_classical_controls(*conditions)
881+
867882
def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
868883
return protocols.control_keys(self.sub_operation)
869884

cirq/ops/raw_types_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,10 @@ def test_tagged_operation_forwards_protocols():
627627
assert controlled_y.qubits == (q2, q1)
628628
assert isinstance(controlled_y, cirq.Operation)
629629
assert not isinstance(controlled_y, cirq.TaggedOperation)
630+
classically_controlled_y = tagged_y.with_classical_controls("a")
631+
assert classically_controlled_y == y.with_classical_controls("a")
632+
assert isinstance(classically_controlled_y, cirq.Operation)
633+
assert not isinstance(classically_controlled_y, cirq.TaggedOperation)
630634

631635
clifford_x = cirq.SingleQubitCliffordGate.X(q1)
632636
tagged_x = cirq.SingleQubitCliffordGate.X(q1).with_tags(tag)
@@ -932,3 +936,12 @@ def __iter__(self):
932936
raise NotImplementedError()
933937

934938
assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())
939+
940+
941+
@pytest.mark.parametrize(
942+
'op', [cirq.X(cirq.NamedQubit("q")), cirq.X(cirq.NamedQubit("q")).with_tags("tagged_op")]
943+
)
944+
def test_with_methods_return_self_on_empty_conditions(op):
945+
assert op is op.with_tags(*[])
946+
assert op is op.with_classical_controls(*[])
947+
assert op is op.controlled_by(*[])

0 commit comments

Comments
 (0)