diff --git a/cirq-core/cirq/devices/noise_model.py b/cirq-core/cirq/devices/noise_model.py index e8430dc4fcb..bf00dd29c5d 100644 --- a/cirq-core/cirq/devices/noise_model.py +++ b/cirq-core/cirq/devices/noise_model.py @@ -88,12 +88,7 @@ def is_virtual_moment(self, moment: 'cirq.Moment') -> bool: """ if not moment.operations: return False - return all( - [ - isinstance(op, ops.TaggedOperation) and ops.VirtualTag() in op.tags - for op in moment.operations - ] - ) + return all(ops.VirtualTag() in op.tags for op in moment) def _noisy_moments_impl_moment( self, moments: 'Iterable[cirq.Moment]', system_qubits: Sequence['cirq.Qid'] diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 36afcc7624b..4ec3059a424 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -425,7 +425,7 @@ def untagged(self) -> 'cirq.Operation': """Returns the underlying operation without any tags.""" return self - def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation': + def with_tags(self, *new_tags: Hashable) -> 'cirq.Operation': """Creates a new TaggedOperation, with this op and the specified tags. This method can be used to attach meta-data to specific operations @@ -443,6 +443,8 @@ def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation': Args: new_tags: The tags to wrap this operation in. """ + if not new_tags: + return self return TaggedOperation(self, *new_tags) def transform_qubits( @@ -608,6 +610,8 @@ def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation': that has the tags of this operation combined with the new_tags specified as the parameter. """ + if not new_tags: + return self return TaggedOperation(self.sub_operation, *self._tags, *new_tags) def __str__(self) -> str: diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 08cf6a9c910..22238b1bec9 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -434,6 +434,14 @@ def test_tagged_operation(): assert not cirq.is_measurement(op) +def test_with_tags_returns_same_instance_if_possible(): + untagged = cirq.X(cirq.GridQubit(1, 1)) + assert untagged.with_tags() is untagged + + tagged = untagged.with_tags('foo') + assert tagged.with_tags() is tagged + + def test_tagged_measurement(): assert not cirq.is_measurement(cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0'))