Skip to content

Commit db9552a

Browse files
authored
Merge branch 'master' into stream-client/fake-quantum-run-stream-timing-support
2 parents 2ef0730 + 6fae409 commit db9552a

File tree

98 files changed

+2060
-838
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+2060
-838
lines changed

Diff for: cirq-core/cirq/circuits/circuit.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,15 @@ def __getitem__(self, key):
272272
def __str__(self) -> str:
273273
return self.to_text_diagram()
274274

275-
def __repr__(self) -> str:
276-
cls_name = self.__class__.__name__
275+
def _repr_args(self) -> str:
277276
args = []
278277
if self.moments:
279278
args.append(_list_repr_with_indented_item_lines(self.moments))
280-
return f'cirq.{cls_name}({", ".join(args)})'
279+
return f'{", ".join(args)}'
280+
281+
def __repr__(self) -> str:
282+
cls_name = self.__class__.__name__
283+
return f'cirq.{cls_name}({self._repr_args()})'
281284

282285
def _repr_pretty_(self, p: Any, cycle: bool) -> None:
283286
"""Print ASCII diagram in Jupyter."""
@@ -1791,7 +1794,6 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
17911794

17921795
# "mop" means current moment-or-operation
17931796
for mop in ops.flatten_to_ops_or_moments(contents):
1794-
17951797
# Identify the index of the moment to place this `mop` into.
17961798
placement_index = get_earliest_accommodating_moment_index(
17971799
mop, qubit_indices, mkey_indices, ckey_indices, length
@@ -2450,7 +2452,6 @@ def _draw_moment_annotations(
24502452
first_annotation_row: int,
24512453
transpose: bool,
24522454
):
2453-
24542455
for k, annotation in enumerate(_get_moment_annotations(moment)):
24552456
args = protocols.CircuitDiagramInfoArgs(
24562457
known_qubits=(),

Diff for: cirq-core/cirq/circuits/frozen_circuit.py

+74-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""An immutable version of the Circuit data structure."""
15-
from typing import AbstractSet, FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union
15+
from typing import (
16+
AbstractSet,
17+
FrozenSet,
18+
Hashable,
19+
Iterable,
20+
Iterator,
21+
Sequence,
22+
Tuple,
23+
TYPE_CHECKING,
24+
Union,
25+
)
1626

1727
import numpy as np
1828

@@ -34,7 +44,10 @@ class FrozenCircuit(AbstractCircuit, protocols.SerializableByKey):
3444
"""
3545

3646
def __init__(
37-
self, *contents: 'cirq.OP_TREE', strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST
47+
self,
48+
*contents: 'cirq.OP_TREE',
49+
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
50+
tags: Sequence[Hashable] = (),
3851
) -> None:
3952
"""Initializes a frozen circuit.
4053
@@ -47,9 +60,14 @@ def __init__(
4760
strategy: When initializing the circuit with operations and moments
4861
from `contents`, this determines how the operations are packed
4962
together.
63+
tags: A sequence of any type of object that is useful to attach metadata
64+
to this circuit as long as the type is hashable. If you wish the
65+
resulting circuit to be eventually serialized into JSON, you should
66+
also restrict the tags to be JSON serializable.
5067
"""
5168
base = Circuit(contents, strategy=strategy)
5269
self._moments = tuple(base.moments)
70+
self._tags = tuple(tags)
5371

5472
@classmethod
5573
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
@@ -61,10 +79,35 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
6179
def moments(self) -> Sequence['cirq.Moment']:
6280
return self._moments
6381

82+
@property
83+
def tags(self) -> Tuple[Hashable, ...]:
84+
"""Returns a tuple of the Circuit's tags."""
85+
return self._tags
86+
87+
@_compat.cached_property
88+
def untagged(self) -> 'cirq.FrozenCircuit':
89+
"""Returns the underlying FrozenCircuit without any tags."""
90+
return self._from_moments(self._moments) if self.tags else self
91+
92+
def with_tags(self, *new_tags: Hashable) -> 'cirq.FrozenCircuit':
93+
"""Creates a new tagged `FrozenCircuit` with `self.tags` and `new_tags` combined."""
94+
if not new_tags:
95+
return self
96+
new_circuit = FrozenCircuit(tags=self.tags + new_tags)
97+
new_circuit._moments = self._moments
98+
return new_circuit
99+
64100
@_compat.cached_method
65101
def __hash__(self) -> int:
66102
# Explicitly cached for performance
67-
return hash((self.moments,))
103+
return hash((self.moments, self.tags))
104+
105+
def __eq__(self, other):
106+
super_eq = super().__eq__(other)
107+
if super_eq is not True:
108+
return super_eq
109+
other_tags = other.tags if isinstance(other, FrozenCircuit) else ()
110+
return self.tags == other_tags
68111

69112
def __getstate__(self):
70113
# Don't save hash when pickling; see #3777.
@@ -130,11 +173,23 @@ def all_measurement_key_names(self) -> FrozenSet[str]:
130173

131174
@_compat.cached_method
132175
def _is_parameterized_(self) -> bool:
133-
return super()._is_parameterized_()
176+
return super()._is_parameterized_() or any(
177+
protocols.is_parameterized(tag) for tag in self.tags
178+
)
134179

135180
@_compat.cached_method
136181
def _parameter_names_(self) -> AbstractSet[str]:
137-
return super()._parameter_names_()
182+
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
183+
return super()._parameter_names_() | tag_params
184+
185+
def _resolve_parameters_(
186+
self, resolver: 'cirq.ParamResolver', recursive: bool
187+
) -> 'cirq.FrozenCircuit':
188+
resolved_circuit = super()._resolve_parameters_(resolver, recursive)
189+
resolved_tags = [
190+
protocols.resolve_parameters(tag, resolver, recursive) for tag in self.tags
191+
]
192+
return resolved_circuit.with_tags(*resolved_tags)
138193

139194
def _measurement_key_names_(self) -> FrozenSet[str]:
140195
return self.all_measurement_key_names()
@@ -161,6 +216,20 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
161216
except:
162217
return NotImplemented
163218

219+
def _repr_args(self) -> str:
220+
moments_repr = super()._repr_args()
221+
tag_repr = ','.join(_compat.proper_repr(t) for t in self._tags)
222+
return f'{moments_repr}, tags=[{tag_repr}]' if self.tags else moments_repr
223+
224+
def _json_dict_(self):
225+
attribute_names = ['moments', 'tags'] if self.tags else ['moments']
226+
ret = protocols.obj_to_dict_helper(self, attribute_names)
227+
return ret
228+
229+
@classmethod
230+
def _from_json_dict_(cls, moments, *, tags=(), **kwargs):
231+
return cls(moments, strategy=InsertStrategy.EARLIEST, tags=tags)
232+
164233
def concat_ragged(
165234
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
166235
) -> 'cirq.FrozenCircuit':

Diff for: cirq-core/cirq/circuits/frozen_circuit_test.py

+32
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
import pytest
20+
import sympy
2021

2122
import cirq
2223

@@ -74,3 +75,34 @@ def test_immutable():
7475
match="(can't set attribute)|(property 'moments' of 'FrozenCircuit' object has no setter)",
7576
):
7677
c.moments = (cirq.Moment(cirq.H(q)), cirq.Moment(cirq.X(q)))
78+
79+
80+
def test_tagged_circuits():
81+
q = cirq.LineQubit(0)
82+
ops = [cirq.X(q), cirq.H(q)]
83+
tags = [sympy.Symbol("a"), "b"]
84+
circuit = cirq.Circuit(ops)
85+
frozen_circuit = cirq.FrozenCircuit(ops)
86+
tagged_circuit = cirq.FrozenCircuit(ops, tags=tags)
87+
# Test equality
88+
assert tagged_circuit.tags == tuple(tags)
89+
assert circuit == frozen_circuit != tagged_circuit
90+
assert cirq.approx_eq(circuit, frozen_circuit)
91+
assert cirq.approx_eq(frozen_circuit, tagged_circuit)
92+
# Test hash
93+
assert hash(frozen_circuit) != hash(tagged_circuit)
94+
# Test _repr_ and _json_ round trips.
95+
cirq.testing.assert_equivalent_repr(tagged_circuit)
96+
cirq.testing.assert_json_roundtrip_works(tagged_circuit)
97+
# Test utility methods and constructors
98+
assert frozen_circuit.with_tags() is frozen_circuit
99+
assert frozen_circuit.with_tags(*tags) == tagged_circuit
100+
assert tagged_circuit.with_tags("c") == cirq.FrozenCircuit(ops, tags=[*tags, "c"])
101+
assert tagged_circuit.untagged == frozen_circuit
102+
assert frozen_circuit.untagged is frozen_circuit
103+
# Test parameterized protocols
104+
assert cirq.is_parameterized(frozen_circuit) is False
105+
assert cirq.is_parameterized(tagged_circuit) is True
106+
assert cirq.parameter_names(tagged_circuit) == {"a"}
107+
# Tags are not propagated to diagrams yet.
108+
assert str(frozen_circuit) == str(tagged_circuit)

Diff for: cirq-core/cirq/experiments/random_quantum_circuit_generation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -693,5 +693,5 @@ def _two_qubit_layer(
693693
prng: 'np.random.RandomState',
694694
) -> 'cirq.OP_TREE':
695695
for a, b in coupled_qubit_pairs:
696-
if (a, b) in layer:
696+
if (a, b) in layer or (b, a) in layer:
697697
yield two_qubit_op_factory(a, b, prng)

Diff for: cirq-core/cirq/experiments/random_quantum_circuit_generation_test.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,30 @@ def __init__(self):
259259
'seed, expected_circuit_length, single_qubit_layers_slice, '
260260
'two_qubit_layers_slice',
261261
(
262+
(
263+
(cirq.q(0, 0), cirq.q(0, 1), cirq.q(0, 2)),
264+
4,
265+
lambda a, b, _: cirq.CZ(a, b),
266+
[[(cirq.q(0, 0), cirq.q(0, 1))], [(cirq.q(0, 1), cirq.q(0, 2))]],
267+
(cirq.X**0.5,),
268+
True,
269+
1234,
270+
9,
271+
slice(None, None, 2),
272+
slice(1, None, 2),
273+
),
274+
(
275+
(cirq.q(0, 0), cirq.q(0, 1), cirq.q(0, 2)),
276+
4,
277+
lambda a, b, _: cirq.CZ(a, b),
278+
[[(cirq.q(0, 1), cirq.q(0, 0))], [(cirq.q(0, 1), cirq.q(0, 2))]],
279+
(cirq.X**0.5,),
280+
True,
281+
1234,
282+
9,
283+
slice(None, None, 2),
284+
slice(1, None, 2),
285+
),
262286
(
263287
cirq.GridQubit.rect(4, 3),
264288
20,
@@ -406,7 +430,10 @@ def _validate_two_qubit_layers(
406430
# Operation is two-qubit
407431
assert cirq.num_qubits(op) == 2
408432
# Operation fits pattern
409-
assert op.qubits in pattern[i % len(pattern)]
433+
assert (
434+
op.qubits in pattern[i % len(pattern)]
435+
or op.qubits[::-1] in pattern[i % len(pattern)]
436+
)
410437
active_pairs.add(op.qubits)
411438
# All interactions that should be in this layer are present
412439
assert all(

Diff for: cirq-core/cirq/ops/dense_pauli_string.py

+3
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,9 @@ def copy(
570570
def __str__(self) -> str:
571571
return super().__str__() + ' (mutable)'
572572

573+
def _value_equality_values_(self):
574+
return self.coefficient, tuple(PAULI_CHARS[p] for p in self.pauli_mask)
575+
573576
@classmethod
574577
def inline_gaussian_elimination(cls, rows: 'List[MutableDensePauliString]') -> None:
575578
if not rows:

Diff for: cirq-core/cirq/ops/linear_combinations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def _pauli_string_from_unit(unit: UnitPauliStringT, coefficient: Union[int, floa
357357
return PauliString(qubit_pauli_map=dict(unit), coefficient=coefficient)
358358

359359

360-
@value.value_equality(approximate=True)
360+
@value.value_equality(approximate=True, unhashable=True)
361361
class PauliSum:
362362
"""Represents operator defined by linear combination of PauliStrings.
363363

Diff for: cirq-core/cirq/protocols/json_serialization.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _register_resolver(dict_factory: Callable[[], Dict[str, ObjectFactory]]) ->
113113
class SupportsJSON(Protocol):
114114
"""An object that can be turned into JSON dictionaries.
115115
116-
The magic method _json_dict_ must return a trivially json-serializable
116+
The magic method `_json_dict_` must return a trivially json-serializable
117117
type or other objects that support the SupportsJSON protocol.
118118
119119
During deserialization, a class must be able to be resolved (see
@@ -150,7 +150,7 @@ def obj_to_dict_helper(obj: Any, attribute_names: Iterable[str]) -> Dict[str, An
150150
"""Construct a dictionary containing attributes from obj
151151
152152
This is useful as a helper function in objects implementing the
153-
SupportsJSON protocol, particularly in the _json_dict_ method.
153+
SupportsJSON protocol, particularly in the `_json_dict_` method.
154154
155155
In addition to keys and values specified by `attribute_names`, the
156156
returned dictionary has an additional key "cirq_type" whose value
@@ -169,7 +169,7 @@ def obj_to_dict_helper(obj: Any, attribute_names: Iterable[str]) -> Dict[str, An
169169

170170
# pylint: enable=redefined-builtin
171171
def dataclass_json_dict(obj: Any) -> Dict[str, Any]:
172-
"""Return a dictionary suitable for _json_dict_ from a dataclass.
172+
"""Return a dictionary suitable for `_json_dict_` from a dataclass.
173173
174174
Dataclasses keep track of their relevant fields, so we can automatically generate these.
175175
@@ -607,7 +607,7 @@ def to_json(
607607
cls: Passed to json.dump; the default value of CirqEncoder
608608
enables the serialization of Cirq objects which implement
609609
the SupportsJSON protocol. To support serialization of 3rd
610-
party classes, prefer adding the _json_dict_ magic method
610+
party classes, prefer adding the `_json_dict_` magic method
611611
to your classes rather than overriding this default.
612612
"""
613613
if has_serializable_by_keys(obj):

Diff for: cirq-core/cirq/sim/density_matrix_simulation_state.py

+16
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,22 @@ def __init__(
285285
)
286286
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
287287

288+
def add_qubits(self, qubits: Sequence['cirq.Qid']):
289+
ret = super().add_qubits(qubits)
290+
return (
291+
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
292+
if ret is NotImplemented
293+
else ret
294+
)
295+
296+
def remove_qubits(self, qubits: Sequence['cirq.Qid']):
297+
ret = super().remove_qubits(qubits)
298+
if ret is not NotImplemented:
299+
return ret
300+
extracted, remainder = self.factor(qubits, inplace=True)
301+
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
302+
return remainder
303+
288304
def _act_on_fallback_(
289305
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
290306
) -> bool:

Diff for: cirq-core/cirq/sim/simulation_state_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,8 @@ def test_delegating_gate_channel(exp):
164164
control_circuit = cirq.Circuit(cirq.H(q))
165165
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))
166166

167-
with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
168-
# TODO: This test should pass once we extend support to DensityMatrixSimulator.
169-
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
167+
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
168+
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
170169

171170

172171
@pytest.mark.parametrize('num_ancilla', [1, 2, 3])

0 commit comments

Comments
 (0)