Skip to content

Commit 6937e41

Browse files
authored
Add ClassicalDataStore class to keep track of qubits measured (#4781)
Adds a `ClassicalDataStore` class so we can keep track of which qubits are associated to which measurements. Closes #3232. Initially this was created as part 14 (of 14) of https://tinyurl.com/cirq-feedforward to enable qudits in classical conditions, by storing and using dimensions of the measured qubits when calculating the integer value of each measurement when resolving sympy expressions. However it may have broader applicability. This approach also sets us up to more easily add different types of measurements (#3233, #4274). It will also ease the path to #3002 and #4449., as we can eventually pass this into `Result` rather than the raw `log_of_measurement_results` dictionary. (The return type of `_run` will have to be changed to `Sequence[C;assicalDataStoreReader]`. Related: #887, #3231 (open question @95-martin-orion whether this closes those or not) This PR contains a `ClassicalDataStoreReader` and `ClassicalDataStoreBase` parent "interface" for the `ClassicalDataStore` class as well. This will allow us to swap in different representations that may have different performance characteristics. See #3808 for an example use case. This could be done by adding an optional `ClassicalDataStore` factory method argument to the `SimulatorBase` initializer, or separately to sampler classes. (Note this is an alternative to #4778 for supporting qudits in sympy classical control expressions, as discussed here: https://github.com/quantumlib/Cirq/pull/4778/files#r774816995. The other PR was simpler and less invasive, but a bit hacky. I felt even though bigger, this seemed like the better approach and especially fits better with our future direction, and closed the other one). **Breaking Changes**: 1. The abstract method `SimulatorBase._create_partial_act_on_args` argument `log_of_measurement_results: Dict` has been changed to `classical_data: ClassicalData`. Any third-party simulators that inherit `SimulatorBase` will need to update their implementation accordingly. 2. The abstract base class `ActOnArgs.__init__` argument `log_of_measurement_results: Dict` is now copied before use. For users that depend on the pass-by-reference semantics (this should be rare), they can use the new `classical_data: ClassicalData` argument instead, which is pass-by-reference.
1 parent 467c68d commit 6937e41

32 files changed

+730
-126
lines changed

cirq-core/cirq/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,9 @@
507507
canonicalize_half_turns,
508508
chosen_angle_to_canonical_half_turns,
509509
chosen_angle_to_half_turns,
510+
ClassicalDataDictionaryStore,
511+
ClassicalDataStore,
512+
ClassicalDataStoreReader,
510513
Condition,
511514
Duration,
512515
DURATION_LIKE,
@@ -515,6 +518,7 @@
515518
LinearDict,
516519
MEASUREMENT_KEY_SEPARATOR,
517520
MeasurementKey,
521+
MeasurementType,
518522
PeriodicValue,
519523
RANDOM_STATE_OR_SEED_LIKE,
520524
state_vector_to_probabilities,

cirq-core/cirq/contrib/quimb/mps_simulator.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _create_partial_act_on_args(
9191
self,
9292
initial_state: Union[int, 'MPSState'],
9393
qubits: Sequence['cirq.Qid'],
94-
logs: Dict[str, Any],
94+
classical_data: 'cirq.ClassicalDataStore',
9595
) -> 'MPSState':
9696
"""Creates MPSState args for simulating the Circuit.
9797
@@ -101,7 +101,8 @@ def _create_partial_act_on_args(
101101
qubits: Determines the canonical ordering of the qubits. This
102102
is often used in specifying the initial state, i.e. the
103103
ordering of the computational basis states.
104-
logs: A mutable object that measurements are recorded into.
104+
classical_data: The shared classical data container for this
105+
simulation.
105106
106107
Returns:
107108
MPSState args for simulating the Circuit.
@@ -115,7 +116,7 @@ def _create_partial_act_on_args(
115116
simulation_options=self.simulation_options,
116117
grouping=self.grouping,
117118
initial_state=initial_state,
118-
log_of_measurement_results=logs,
119+
classical_data=classical_data,
119120
)
120121

121122
def _create_step_result(
@@ -229,6 +230,7 @@ def __init__(
229230
grouping: Optional[Dict['cirq.Qid', int]] = None,
230231
initial_state: int = 0,
231232
log_of_measurement_results: Dict[str, Any] = None,
233+
classical_data: 'cirq.ClassicalDataStore' = None,
232234
):
233235
"""Creates and MPSState
234236
@@ -242,11 +244,18 @@ def __init__(
242244
initial_state: An integer representing the initial state.
243245
log_of_measurement_results: A mutable object that measurements are
244246
being recorded into.
247+
classical_data: The shared classical data container for this
248+
simulation.
245249
246250
Raises:
247251
ValueError: If the grouping does not cover the qubits.
248252
"""
249-
super().__init__(prng, qubits, log_of_measurement_results)
253+
super().__init__(
254+
prng=prng,
255+
qubits=qubits,
256+
log_of_measurement_results=log_of_measurement_results,
257+
classical_data=classical_data,
258+
)
250259
qubit_map = self.qubit_map
251260
self.grouping = qubit_map if grouping is None else grouping
252261
if self.grouping.keys() != self.qubit_map.keys():

cirq-core/cirq/contrib/quimb/mps_simulator_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -550,10 +550,10 @@ def test_state_act_on_args_initializer():
550550
s = ccq.mps_simulator.MPSState(
551551
qubits=(cirq.LineQubit(0),),
552552
prng=np.random.RandomState(0),
553-
log_of_measurement_results={'test': 4},
553+
log_of_measurement_results={'test': [4]},
554554
)
555555
assert s.qubits == (cirq.LineQubit(0),)
556-
assert s.log_of_measurement_results == {'test': 4}
556+
assert s.log_of_measurement_results == {'test': [4]}
557557

558558

559559
def test_act_on_gate():

cirq-core/cirq/json_resolver_cache.py

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def _parallel_gate_op(gate, qubits):
6565
'Circuit': cirq.Circuit,
6666
'CircuitOperation': cirq.CircuitOperation,
6767
'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation,
68+
'ClassicalDataDictionaryStore': cirq.ClassicalDataDictionaryStore,
6869
'CliffordState': cirq.CliffordState,
6970
'CliffordTableau': cirq.CliffordTableau,
7071
'CNotPowGate': cirq.CNotPowGate,
@@ -107,6 +108,7 @@ def _parallel_gate_op(gate, qubits):
107108
'MixedUnitaryChannel': cirq.MixedUnitaryChannel,
108109
'MeasurementKey': cirq.MeasurementKey,
109110
'MeasurementGate': cirq.MeasurementGate,
111+
'MeasurementType': cirq.MeasurementType,
110112
'_MeasurementSpec': cirq.work._MeasurementSpec,
111113
'Moment': cirq.Moment,
112114
'MutableDensePauliString': cirq.MutableDensePauliString,

cirq-core/cirq/ops/classically_controlled_operation.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def _circuit_diagram_info_(
148148
sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None)
149149
if sub_info is None:
150150
return NotImplemented # coverage: ignore
151-
152151
control_count = len({k for c in self._conditions for k in c.keys})
153152
wire_symbols = sub_info.wire_symbols + ('^',) * control_count
154153
if any(not isinstance(c, value.KeyCondition) for c in self._conditions):
@@ -176,7 +175,7 @@ def _json_dict_(self) -> Dict[str, Any]:
176175
}
177176

178177
def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
179-
if all(c.resolve(args.log_of_measurement_results) for c in self._conditions):
178+
if all(c.resolve(args.classical_data) for c in self._conditions):
180179
protocols.act_on(self._sub_operation, args)
181180
return True
182181

cirq-core/cirq/ops/classically_controlled_operation_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
import numpy as np
1416
import pytest
1517
import sympy
1618
from sympy.parsing import sympy_parser
@@ -702,6 +704,40 @@ def test_sympy():
702704
assert result.measurements['m_result'][0][0] == (j > i)
703705

704706

707+
def test_sympy_qudits():
708+
q0 = cirq.LineQid(0, 3)
709+
q1 = cirq.LineQid(1, 5)
710+
q_result = cirq.LineQubit(2)
711+
712+
class PlusGate(cirq.Gate):
713+
def __init__(self, dimension, increment=1):
714+
self.dimension = dimension
715+
self.increment = increment % dimension
716+
717+
def _qid_shape_(self):
718+
return (self.dimension,)
719+
720+
def _unitary_(self):
721+
inc = (self.increment - 1) % self.dimension + 1
722+
u = np.empty((self.dimension, self.dimension))
723+
u[inc:] = np.eye(self.dimension)[:-inc]
724+
u[:inc] = np.eye(self.dimension)[-inc:]
725+
return u
726+
727+
for i in range(15):
728+
digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=(3, 5))
729+
circuit = cirq.Circuit(
730+
PlusGate(3, digits[0]).on(q0),
731+
PlusGate(5, digits[1]).on(q1),
732+
cirq.measure(q0, q1, key='m'),
733+
cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m % 4 <= 1')),
734+
cirq.measure(q_result, key='m_result'),
735+
)
736+
737+
result = cirq.Simulator().run(circuit)
738+
assert result.measurements['m_result'][0][0] == (i % 4 <= 1)
739+
740+
705741
def test_sympy_path_prefix():
706742
q = cirq.LineQubit(0)
707743
op = cirq.X(q).with_classical_controls(sympy.Symbol('b'))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
{
2+
"cirq_type": "ClassicalDataDictionaryStore",
3+
"measurements": [
4+
[
5+
{
6+
"cirq_type": "MeasurementKey",
7+
"name": "m",
8+
"path": []
9+
},
10+
[0, 1]
11+
]
12+
],
13+
"measured_qubits": [
14+
[
15+
{
16+
"cirq_type": "MeasurementKey",
17+
"name": "m",
18+
"path": []
19+
},
20+
[
21+
{
22+
"cirq_type": "LineQubit",
23+
"x": 0
24+
},
25+
{
26+
"cirq_type": "LineQubit",
27+
"x": 1
28+
}
29+
]
30+
]
31+
],
32+
"channel_measurements": [
33+
[
34+
{
35+
"cirq_type": "MeasurementKey",
36+
"name": "c",
37+
"path": []
38+
},
39+
3
40+
]
41+
],
42+
"measurement_types": [
43+
[
44+
{
45+
"cirq_type": "MeasurementKey",
46+
"name": "m",
47+
"path": []
48+
},
49+
1
50+
],
51+
[
52+
{
53+
"cirq_type": "MeasurementKey",
54+
"name": "c",
55+
"path": []
56+
},
57+
2
58+
]
59+
]
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[1, 2]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[cirq.MeasurementType.MEASUREMENT, cirq.MeasurementType.CHANNEL]

cirq-core/cirq/protocols/measurement_key_protocol.py

-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
from cirq import value
2121
from cirq._doc import doc_private
2222

23-
if TYPE_CHECKING:
24-
import cirq
25-
2623
if TYPE_CHECKING:
2724
import cirq
2825

cirq-core/cirq/sim/act_on_args.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import numpy as np
3333

34-
from cirq import protocols, ops
34+
from cirq import ops, protocols, value
3535
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
3636
from cirq.sim.operation_target import OperationTarget
3737

@@ -50,6 +50,7 @@ def __init__(
5050
qubits: Optional[Sequence['cirq.Qid']] = None,
5151
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
5252
ignore_measurement_results: bool = False,
53+
classical_data: Optional['cirq.ClassicalDataStore'] = None,
5354
):
5455
"""Inits ActOnArgs.
5556
@@ -65,16 +66,21 @@ def __init__(
6566
will treat measurement as dephasing instead of collapsing
6667
process, and not log the result. This is only applicable to
6768
simulators that can represent mixed states.
69+
classical_data: The shared classical data container for this
70+
simulation.
6871
"""
6972
if prng is None:
7073
prng = cast(np.random.RandomState, np.random)
7174
if qubits is None:
7275
qubits = ()
73-
if log_of_measurement_results is None:
74-
log_of_measurement_results = {}
7576
self._set_qubits(qubits)
7677
self.prng = prng
77-
self._log_of_measurement_results = log_of_measurement_results
78+
self._classical_data = classical_data or value.ClassicalDataDictionaryStore(
79+
_measurements={
80+
value.MeasurementKey.parse_serialized(k): tuple(v)
81+
for k, v in (log_of_measurement_results or {}).items()
82+
}
83+
)
7884
self._ignore_measurement_results = ignore_measurement_results
7985

8086
def _set_qubits(self, qubits: Sequence['cirq.Qid']):
@@ -103,9 +109,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
103109
return
104110
bits = self._perform_measurement(qubits)
105111
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
106-
if key in self._log_of_measurement_results:
107-
raise ValueError(f"Measurement already logged to key {key!r}")
108-
self._log_of_measurement_results[key] = corrected
112+
self._classical_data.record_measurement(
113+
value.MeasurementKey.parse_serialized(key), corrected, qubits
114+
)
109115

110116
def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
111117
return [self.qubit_map[q] for q in qubits]
@@ -138,7 +144,7 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
138144
DeprecationWarning,
139145
)
140146
self._on_copy(args)
141-
args._log_of_measurement_results = self.log_of_measurement_results.copy()
147+
args._classical_data = self._classical_data.copy()
142148
return args
143149

144150
def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
@@ -236,8 +242,8 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ
236242
functionality, if supported."""
237243

238244
@property
239-
def log_of_measurement_results(self) -> Dict[str, List[int]]:
240-
return self._log_of_measurement_results
245+
def classical_data(self) -> 'cirq.ClassicalDataStoreReader':
246+
return self._classical_data
241247

242248
@property
243249
def ignore_measurement_results(self) -> bool:

0 commit comments

Comments
 (0)