From 9a9db4c5192ccdce73031d8c6c1f7e0e325c0208 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Sun, 17 May 2020 19:43:17 -0700 Subject: [PATCH 1/6] Add act_on protocol - Add act_on_protocol.py - Add _act_on_ to MeasurementGate and ResetChannel - Add ActOnStateVectorArgs to represent state vector simulator state - Move logic out of sparse_simulator.py and into ActOnStateVectorArgs._fallback_act_on - Fix ApplyUnitaryArgs.subspace_index not forwarding shape information - Add allow_decompose to several protocol methods - Fix CliffordSimulator accepting measurements that may not be stabilizers - Fix _verify_unique_measurement_keys not using measurement_key**s** - Add test that terminal measurement simulations can have random results - with_gate now returns self if new_gate is gate - GateOperation now delegates protocols directly to its gate when possible --- cirq/__init__.py | 2 + cirq/ops/common_channels.py | 25 ++- cirq/ops/gate_operation.py | 64 ++++-- cirq/ops/gate_operation_test.py | 52 +++++ cirq/ops/measurement_gate.py | 25 ++- cirq/protocols/__init__.py | 4 + cirq/protocols/act_on_protocol.py | 121 ++++++++++ cirq/protocols/apply_unitary_protocol.py | 21 +- cirq/protocols/apply_unitary_protocol_test.py | 6 +- cirq/protocols/has_unitary_protocol.py | 4 +- cirq/protocols/has_unitary_protocol_test.py | 6 + cirq/protocols/measurement_key_protocol.py | 3 +- cirq/protocols/mixture_protocol.py | 24 +- cirq/protocols/mixture_protocol_test.py | 22 ++ cirq/sim/__init__.py | 3 + cirq/sim/act_on_state_vector_args.py | 212 ++++++++++++++++++ cirq/sim/clifford/clifford_simulator.py | 16 +- cirq/sim/simulator.py | 4 +- cirq/sim/sparse_simulator.py | 207 ++++------------- cirq/sim/sparse_simulator_test.py | 7 + docs/api.rst | 2 + 21 files changed, 631 insertions(+), 199 deletions(-) create mode 100644 cirq/protocols/act_on_protocol.py create mode 100644 cirq/sim/act_on_state_vector_args.py diff --git a/cirq/__init__.py b/cirq/__init__.py index 6c4f2cdb9e9..879abd66789 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -315,6 +315,7 @@ ) from cirq.sim import ( + ActOnStateVectorArgs, StabilizerStateChForm, CIRCUIT_LIKE, CliffordSimulator, @@ -394,6 +395,7 @@ # pylint: disable=redefined-builtin from cirq.protocols import ( + act_on, apply_channel, apply_mixture, apply_unitaries, diff --git a/cirq/ops/common_channels.py b/cirq/ops/common_channels.py index f2ed4d67140..024adec42ec 100644 --- a/cirq/ops/common_channels.py +++ b/cirq/ops/common_channels.py @@ -19,7 +19,7 @@ import numpy as np -from cirq import protocols, value +from cirq import protocols, value, linalg from cirq.ops import (raw_types, common_gates, pauli_gates, gate_features, identity) @@ -552,6 +552,29 @@ def __init__(self, dimension: int = 2) -> None: def _qid_shape_(self): return (self._dimension,) + def _act_on_(self, args: Any): + from cirq import sim + + if isinstance(args, sim.ActOnStateVectorArgs): + # Do a silent measurement. + measurements, _ = sim.measure_state_vector( + args.target_tensor, + args.axes, + out=args.target_tensor, + qid_shape=args.target_tensor.shape) + result = measurements[0] + + # Use measurement result to zero the qid. + if result: + zero = args.subspace_index(0) + other = args.subspace_index(result) + args.target_tensor[zero] = args.target_tensor[other] + args.target_tensor[other] = 0 + + return True + + return NotImplemented + def _channel_(self) -> Iterable[np.ndarray]: # The first axis is over the list of channel matrices channel = np.zeros((self._dimension,) * 3, dtype=np.complex64) diff --git a/cirq/ops/gate_operation.py b/cirq/ops/gate_operation.py index d7300b28a24..ecdda3632ef 100644 --- a/cirq/ops/gate_operation.py +++ b/cirq/ops/gate_operation.py @@ -16,7 +16,7 @@ import re from typing import (Any, Dict, FrozenSet, Iterable, List, Optional, Sequence, - Tuple, TypeVar, Union, TYPE_CHECKING) + Tuple, TypeVar, Union, TYPE_CHECKING, Type) import numpy as np @@ -56,6 +56,8 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.Operation': return self.gate.on(*new_qubits) def with_gate(self, new_gate: 'cirq.Gate') -> 'cirq.Operation': + if self.gate is new_gate: + return self return new_gate.on(*self.qubits) def __repr__(self): @@ -103,7 +105,7 @@ def _value_equality_values_(self): return self.gate, self._group_interchangeable_qubits() def _qid_shape_(self): - return protocols.qid_shape(self.gate) + return self.gate._qid_shape_() def _num_qubits_(self): return len(self._qubits) @@ -114,33 +116,57 @@ def _decompose_(self) -> 'cirq.OP_TREE': NotImplemented) def _pauli_expansion_(self) -> value.LinearDict[str]: - return protocols.pauli_expansion(self.gate) + getter = getattr(self.gate, '_pauli_expansion_', None) + if getter is not None: + return getter() + return NotImplemented def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs' ) -> Union[np.ndarray, None, NotImplementedType]: - return protocols.apply_unitary(self.gate, args, default=None) + getter = getattr(self.gate, '_apply_unitary_', None) + if getter is not None: + return getter(args) + return NotImplemented def _has_unitary_(self) -> bool: - return protocols.has_unitary(self.gate) + getter = getattr(self.gate, '_has_unitary_', None) + if getter is not None: + return getter() + return NotImplemented def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: - return protocols.unitary(self.gate, default=None) + getter = getattr(self.gate, '_unitary_', None) + if getter is not None: + return getter() + return NotImplemented def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]: return self.gate._commutes_on_qids_(self.qubits, other, atol=atol) def _has_mixture_(self) -> bool: - return protocols.has_mixture(self.gate) + getter = getattr(self.gate, '_has_mixture_', None) + if getter is not None: + return getter() + return NotImplemented def _mixture_(self) -> Sequence[Tuple[float, Any]]: - return protocols.mixture(self.gate, NotImplemented) + getter = getattr(self.gate, '_mixture_', None) + if getter is not None: + return getter() + return NotImplemented def _has_channel_(self) -> bool: - return protocols.has_channel(self.gate) + getter = getattr(self.gate, '_has_channel_', None) + if getter is not None: + return getter() + return NotImplemented def _channel_(self) -> Union[Tuple[np.ndarray], NotImplementedType]: - return protocols.channel(self.gate, NotImplemented) + getter = getattr(self.gate, '_channel_', None) + if getter is not None: + return getter() + return NotImplemented def _measurement_key_(self) -> Optional[str]: getter = getattr(self.gate, '_measurement_key_', None) @@ -154,12 +180,21 @@ def _measurement_keys_(self) -> Optional[Iterable[str]]: return getter() return NotImplemented + def _act_on_(self, args: Any): + getter = getattr(self.gate, '_act_on_', None) + if getter is not None: + return getter(args) + return NotImplemented + def _is_parameterized_(self) -> bool: - return protocols.is_parameterized(self.gate) + getter = getattr(self.gate, '_is_parameterized_', None) + if getter is not None: + return getter() + return NotImplemented def _resolve_parameters_(self, resolver): resolved_gate = protocols.resolve_parameters(self.gate, resolver) - return GateOperation(resolved_gate, self._qubits) + return self.with_gate(resolved_gate) def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs' ) -> 'cirq.CircuitDiagramInfo': @@ -174,7 +209,10 @@ def _decompose_into_clifford_(self): return sub(self.qubits) def _trace_distance_bound_(self) -> float: - return protocols.trace_distance_bound(self.gate) + getter = getattr(self.gate, '_trace_distance_bound_', None) + if getter is not None: + return getter() + return NotImplemented def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'GateOperation': diff --git a/cirq/ops/gate_operation_test.py b/cirq/ops/gate_operation_test.py index c0ba51e6bd5..53cee754f54 100644 --- a/cirq/ops/gate_operation_test.py +++ b/cirq/ops/gate_operation_test.py @@ -207,6 +207,22 @@ def test_pauli_expansion(): assert (cirq.pauli_expansion(cirq.CNOT(a, b)) == cirq.pauli_expansion( cirq.CNOT)) + class No(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + class Yes(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _pauli_expansion_(self): + return cirq.LinearDict({'X': 0.5}) + + assert cirq.pauli_expansion(No().on(a), default=None) is None + assert cirq.pauli_expansion(Yes().on(a)) == cirq.LinearDict({'X': 0.5}) + def test_unitary(): a = cirq.NamedQubit('a') @@ -344,3 +360,39 @@ def _mul_with_qubits(self, qubits, other): # Handles the symmetric type case correctly. assert m * m == 6 assert r * r == 4 + + +def test_with_gate(): + g1 = cirq.GateOperation(cirq.X, cirq.LineQubit.range(1)) + g2 = cirq.GateOperation(cirq.Y, cirq.LineQubit.range(1)) + assert g1.with_gate(cirq.X) is g1 + assert g1.with_gate(cirq.Y) == g2 + + +def test_is_parameterized(): + + class No1(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + class No2(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _is_parameterized_(self): + return False + + class Yes(cirq.Gate): + + def num_qubits(self) -> int: + return 1 + + def _is_parameterized_(self): + return True + + q = cirq.LineQubit(0) + assert not cirq.is_parameterized(No1().on(q)) + assert not cirq.is_parameterized(No2().on(q)) + assert cirq.is_parameterized(Yes().on(q)) diff --git a/cirq/ops/measurement_gate.py b/cirq/ops/measurement_gate.py index 05e3c1cd828..1b1248abce4 100644 --- a/cirq/ops/measurement_gate.py +++ b/cirq/ops/measurement_gate.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING +from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, \ + TYPE_CHECKING, Union import numpy as np @@ -215,6 +216,28 @@ def _from_json_dict_(cls, invert_mask=tuple(invert_mask), qid_shape=None if qid_shape is None else tuple(qid_shape)) + def _act_on_(self, args: Any) -> bool: + from cirq import sim + + if isinstance(args, sim.ActOnStateVectorArgs): + + invert_mask = self.full_invert_mask() + bits, _ = sim.measure_state_vector( + args.target_tensor, + args.axes, + out=args.target_tensor, + qid_shape=args.target_tensor.shape, + seed=args.prng) + corrected = [ + bit ^ (bit < 2 and mask) + for bit, mask in zip(bits, invert_mask) + ] + args.record_measurement_result(self.key, corrected) + + return True + + return NotImplemented + def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str: return ','.join(str(q) for q in qubits) diff --git a/cirq/protocols/__init__.py b/cirq/protocols/__init__.py index bbf9e8d3585..ccd5d295940 100644 --- a/cirq/protocols/__init__.py +++ b/cirq/protocols/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. +from cirq.protocols.act_on_protocol import ( + act_on, + SupportsActOn, +) from cirq.protocols.apply_unitary_protocol import ( apply_unitaries, apply_unitary, diff --git a/cirq/protocols/act_on_protocol.py b/cirq/protocols/act_on_protocol.py new file mode 100644 index 00000000000..5b61fdcb24e --- /dev/null +++ b/cirq/protocols/act_on_protocol.py @@ -0,0 +1,121 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A protocol that wouldn't exist if python had __rimul__.""" + +from typing import (Any, TYPE_CHECKING, Union) + +from typing_extensions import Protocol + +from cirq._doc import document +from cirq.type_workarounds import NotImplementedType + +if TYPE_CHECKING: + pass + + +class SupportsActOn(Protocol): + """An object that explicitly specifies how to act on simulator states.""" + + @document + def _act_on_(self, args: Any) -> Union[NotImplementedType, bool]: + """Applies an action to the given argument, if it is a supported type. + + For example, unitary operations can implement an `_act_on_` method that + checks if `isinstance(args, cirq.ActOnStateVectorArgs)` and, if so, + apply their unitary effect to the state vector. + + The global `cirq.act_on` method looks for whether or not the given + argument has this value, before attempting any fallback strategies + specified by the argument being acted on. + + This method is analogous to python's `__imul__` in that it is expected + to perform an inline effect if it recognizes the type of an argument, + and return NotImplemented otherwise. It is also analogous to python's + `__rmul__` in that dispatch is being done on the right hand side value + instead of the left hand side value. If python had an `__rimul__` + method, then `_act_on_` would not exist because it would be redundant. + + Args: + args: An object of unspecified type. The method must check if this + object is of a recognized type and act on it if so. + + Returns: + True: The receiving object (`self`) acted on the argument. + NotImplemented: The receiving object did not act on the argument. + + All other return values are considered to be errors. + """ + + +def act_on( + action: Any, + args: Any, + *, + allow_decompose: bool = True, +): + """Applies an action to a state argument. + + For example, the action may be a `cirq.Operation` and the state argument may + represent the internal state of a state vector simulator (a + `cirq.ActOnStateVectorArgs`). + + The action is applied by first checking if `action._act_on_` exists and + returns `True` (instead of `NotImplemented`) for the given object. Then + fallback strategies specified by the state argument via `_act_on_fallback_` + are attempted. If those also fail, the method fails with a `TypeError`. + + Args: + action: The action to apply to the state tensor. Typically a + `cirq.Operation`. + args: A mutable state object that should be modified by the action. May + specify an `_act_on_fallback_` method to use in case the action + doesn't recognize it. + allow_decompose: Defaults to True. Forwarded into the + `_act_on_fallback_` method of `args`. Determines if decomposition + should be used or avoided when attempting to act `action` on `args`. + Used by internal methods to avoid redundant decompositions. + + Returns: + Nothing. Results are communicated by editing `args`. + + Raises: + TypeError: Failed to act `action` on `args`. + """ + + action_act_on = getattr(action, '_act_on_', None) + if action_act_on is not None: + result = action_act_on(args) + if result is True: + return + if result is not NotImplemented: + raise ValueError( + f'_act_on_ must return True or NotImplemented but got ' + f'{result!r} from {action!r}._act_on_') + + arg_fallback = getattr(args, '_act_on_fallback_') + if arg_fallback is not None: + result = arg_fallback(action, allow_decompose=allow_decompose) + if result is True: + return + if result is not NotImplemented: + raise ValueError( + f'_act_on_fallback_ must return True or NotImplemented but got ' + f'{result!r} from {type(args)}._act_on_fallback_') + + raise TypeError("Failed to act action on state argument.\n" + "Tried both action._act_on_ and args._act_on_fallback_.\n" + "\n" + f"State argument type: {type(args)}\n" + f"Action type: {type(action)}\n" + f"Action repr: {action!r}\n") diff --git a/cirq/protocols/apply_unitary_protocol.py b/cirq/protocols/apply_unitary_protocol.py index 986110e6868..9513e6f7071 100644 --- a/cirq/protocols/apply_unitary_protocol.py +++ b/cirq/protocols/apply_unitary_protocol.py @@ -215,7 +215,8 @@ def subspace_index(self, return linalg.slice_for_qubits_equal_to( self.axes, little_endian_qureg_value=little_endian_bits_int, - big_endian_qureg_value=big_endian_bits_int) + big_endian_qureg_value=big_endian_bits_int, + qid_shape=self.target_tensor.shape) class SupportsConsistentApplyUnitary(Protocol): @@ -265,10 +266,13 @@ def _apply_unitary_(self, args: ApplyUnitaryArgs """ -def apply_unitary(unitary_value: Any, - args: ApplyUnitaryArgs, - default: TDefault = RaiseTypeErrorIfNotProvided - ) -> Union[np.ndarray, TDefault]: +def apply_unitary( + unitary_value: Any, + args: ApplyUnitaryArgs, + default: TDefault = RaiseTypeErrorIfNotProvided, + *, + allow_decompose: bool = True, +) -> Union[np.ndarray, TDefault]: """High performance left-multiplication of a unitary effect onto a tensor. Applies the unitary effect of `unitary_value` to the tensor specified in @@ -290,7 +294,7 @@ def apply_unitary(unitary_value: Any, Case c) Method returns a numpy array. Multiply the matrix onto the target tensor and return to the caller. - C. Try to use `unitary_value._decompose_()`. + C. Try to use `unitary_value._decompose_()` (if `allow_decompose`). Case a) Method not present or returns `NotImplemented` or `None`. Continue to next strategy. Case b) Method returns an OP_TREE. @@ -311,6 +315,9 @@ def apply_unitary(unitary_value: Any, default: What should be returned if `unitary_value` doesn't have a unitary effect. If not specified, a TypeError is raised instead of returning a default value. + allow_decompose: Defaults to True. If set to False, and applying the + unitary effect requires decomposing the object, the method will + pretend the object has no unitary effect. Returns: If the receiving object does not have a unitary effect, then the @@ -341,6 +348,8 @@ def apply_unitary(unitary_value: Any, _strat_apply_unitary_from_decompose, _strat_apply_unitary_from_unitary ] + if not allow_decompose: + strats.remove(_strat_apply_unitary_from_decompose) # Try each strategy, stopping if one works. for strat in strats: diff --git a/cirq/protocols/apply_unitary_protocol_test.py b/cirq/protocols/apply_unitary_protocol_test.py index 2fc5e95940f..ba5d5960ed3 100644 --- a/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq/protocols/apply_unitary_protocol_test.py @@ -275,8 +275,10 @@ def test_big_endian_subspace_index(): state = np.zeros(shape=(2, 3, 4, 5, 1, 6, 1, 1)) args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [1, 3]) s = slice(None) - assert args.subspace_index(little_endian_bits_int=1) == (s, 1, s, 0, ...) - assert args.subspace_index(big_endian_bits_int=1) == (s, 0, s, 1, ...) + assert args.subspace_index(little_endian_bits_int=1) == (s, 1, s, 0, s, s, + s, s) + assert args.subspace_index(big_endian_bits_int=1) == (s, 0, s, 1, s, s, s, + s) def test_apply_unitaries(): diff --git a/cirq/protocols/has_unitary_protocol.py b/cirq/protocols/has_unitary_protocol.py index 98f5c818138..048fb6959d3 100644 --- a/cirq/protocols/has_unitary_protocol.py +++ b/cirq/protocols/has_unitary_protocol.py @@ -52,7 +52,7 @@ def _has_unitary_(self) -> bool: """ -def has_unitary(val: Any) -> bool: +def has_unitary(val: Any, *, allow_decompose: bool = True) -> bool: """Determines whether the value has a unitary effect. Determines whether `val` has a unitary effect by attempting the following @@ -104,6 +104,8 @@ def has_unitary(val: Any) -> bool: _strat_has_unitary_from_has_unitary, _strat_has_unitary_from_decompose, _strat_has_unitary_from_apply_unitary, _strat_has_unitary_from_unitary ] + if not allow_decompose: + strats.remove(_strat_has_unitary_from_decompose) for strat in strats: result = strat(val) if result is not None: diff --git a/cirq/protocols/has_unitary_protocol_test.py b/cirq/protocols/has_unitary_protocol_test.py index ed2381c9487..b636b790c55 100644 --- a/cirq/protocols/has_unitary_protocol_test.py +++ b/cirq/protocols/has_unitary_protocol_test.py @@ -47,6 +47,7 @@ def _unitary_(self): assert not cirq.has_unitary(No1()) assert not cirq.has_unitary(No2()) assert cirq.has_unitary(Yes()) + assert cirq.has_unitary(Yes(), allow_decompose=False) def test_via_apply_unitary(): @@ -82,6 +83,7 @@ def _apply_unitary_(self, args): return args.target_tensor assert cirq.has_unitary(Yes1()) + assert cirq.has_unitary(Yes1(), allow_decompose=False) assert cirq.has_unitary(Yes2()) assert not cirq.has_unitary(No1()) assert not cirq.has_unitary(No2()) @@ -122,6 +124,10 @@ def _decompose_(self): assert not cirq.has_unitary(No2()) assert not cirq.has_unitary(No3()) + assert not cirq.has_unitary(Yes1(), allow_decompose=False) + assert not cirq.has_unitary(Yes2(), allow_decompose=False) + assert not cirq.has_unitary(No1(), allow_decompose=False) + def test_via_has_unitary(): diff --git a/cirq/protocols/measurement_key_protocol.py b/cirq/protocols/measurement_key_protocol.py index 064bdaf623b..eea88b01195 100644 --- a/cirq/protocols/measurement_key_protocol.py +++ b/cirq/protocols/measurement_key_protocol.py @@ -109,7 +109,8 @@ def measurement_keys(val: Any, *, don't directly specify their measurement keys will be decomposed in order to find measurement keys within the decomposed operations. If not set, composite operations will appear to have no measurement - keys. + keys. Used by internal methods to stop redundant decompositions from + being performed. Returns: The measurement keys of the value. If the value has no measurement, diff --git a/cirq/protocols/mixture_protocol.py b/cirq/protocols/mixture_protocol.py index bbdf2874308..edc92713586 100644 --- a/cirq/protocols/mixture_protocol.py +++ b/cirq/protocols/mixture_protocol.py @@ -19,6 +19,8 @@ from typing_extensions import Protocol from cirq._doc import document +from cirq.protocols.decompose_protocol import \ + _try_decompose_into_operations_and_qubits from cirq.protocols.has_unitary_protocol import has_unitary from cirq.type_workarounds import NotImplementedType @@ -161,12 +163,22 @@ def mixture_channel(val: Any, default: Any = RaiseTypeErrorIfNotProvided "method, but it returned NotImplemented.".format(type(val))) -def has_mixture_channel(val: Any) -> bool: +def has_mixture_channel(val: Any, *, allow_decompose: bool = True) -> bool: """Returns whether the value has a mixture channel representation. In contrast to `has_mixture` this method falls back to checking whether the value has a unitary representation via `has_channel`. + Args: + val: The value to check. + allow_decompose: Used by internal methods to stop redundant + decompositions from being performed (e.g. there's no need to + decompose an object to check if it is unitary as part of determining + if the object is a quantum channel, when the quantum channel check + will already be doing a more general decomposition check). Defaults + to True. When false, the decomposition strategy for determining + the result is skipped. + Returns: If `val` has a `_has_mixture_` method and its result is not NotImplemented, that result is returned. Otherwise, if `val` has a @@ -180,9 +192,13 @@ def has_mixture_channel(val: Any) -> bool: if result is not NotImplemented: return result - result = has_unitary(val) - if result is not NotImplemented and result: - return result + if has_unitary(val, allow_decompose=False): + return True + + if allow_decompose: + operations, _, _ = _try_decompose_into_operations_and_qubits(val) + if operations is not None: + return all(has_mixture_channel(val) for val in operations) # No _has_mixture_ or _has_unitary_ function, use _mixture_ instead. return mixture_channel(val, None) is not None diff --git a/cirq/protocols/mixture_protocol_test.py b/cirq/protocols/mixture_protocol_test.py index efef5512bb2..7b14fcd8365 100644 --- a/cirq/protocols/mixture_protocol_test.py +++ b/cirq/protocols/mixture_protocol_test.py @@ -147,6 +147,28 @@ def test_has_mixture_channel(): assert cirq.has_mixture_channel(ReturnsUnitary()) assert not cirq.has_mixture_channel(ReturnsNotImplementedUnitary()) + class NoAtom(cirq.Operation): + + @property + def qubits(self): + return cirq.LineQubit.range(2) + + def with_qubits(self): + raise NotImplementedError() + + class No1: + + def _decompose_(self): + return [NoAtom()] + + class Yes1: + + def _decompose_(self): + return [cirq.X(cirq.LineQubit(0))] + + assert not cirq.has_mixture_channel(No1()) + assert cirq.has_mixture_channel(Yes1()) + def test_valid_mixture(): cirq.validate_mixture(ReturnsValidTuple()) diff --git a/cirq/sim/__init__.py b/cirq/sim/__init__.py index d6d379c89b4..bdf79d4a14d 100644 --- a/cirq/sim/__init__.py +++ b/cirq/sim/__init__.py @@ -14,6 +14,9 @@ """Base simulation classes and generic simulators.""" +from cirq.sim.act_on_state_vector_args import ( + ActOnStateVectorArgs,) + from cirq.sim.density_matrix_utils import ( measure_density_matrix, sample_density_matrix, diff --git a/cirq/sim/act_on_state_vector_args.py b/cirq/sim/act_on_state_vector_args.py new file mode 100644 index 00000000000..c34e4cac7d1 --- /dev/null +++ b/cirq/sim/act_on_state_vector_args.py @@ -0,0 +1,212 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Objects and methods for acting efficiently on a state vector.""" + +from typing import (Any, Iterable, Sequence, Tuple, TYPE_CHECKING, Union, + DefaultDict, List, Dict) + +import numpy as np + +from cirq import linalg, protocols +from cirq.protocols.decompose_protocol import ( + _try_decompose_into_operations_and_qubits,) + +if TYPE_CHECKING: + import cirq + + +class ActOnStateVectorArgs: + """State and context for an operation acting on a state vector.""" + + def __init__(self, target_tensor: np.ndarray, available_buffer: np.ndarray, + axes: Iterable[int], prng: np.random.RandomState, + log_of_measurement_results: Dict[str, Any]): + """ + Args: + target_tensor: The state vector to act on, stored as a numpy array + with one dimension for each qubit in the system. Operations are + expected to perform inplace edits of this object. + available_buffer: A workspace with the same shape and dtype as + `target_tensor`. The result of an operation can be put into this + buffer, instead of directly editing `target_tensor`, if + `swap_target_tensor_for` is called afterward. + axes: The indices of axes corresponding to the qubits that the + operation is supposed to act upon. + prng: The pseudo random number generator to use for probabilistic + effects. + log_of_measurement_results: A mutable object that measurements are + being recorded into. Edit it easily by calling + `ActOnStateVectorArgs.record_measurement_result`. + """ + self.target_tensor = target_tensor + self.available_buffer = available_buffer + self.axes = tuple(axes) + self.prng = prng + self.log_of_measurement_results = log_of_measurement_results + + def swap_target_tensor_for(self, new_target_tensor: np.ndarray): + """Gives a new state vector for the system. + + Typically, the new state vector should be `args.available_buffer` where + `args` is this `cirq.ActOnStateVectorArgs` instance. + + Args: + new_target_tensor: The new system state. Must have the same shape + and dtype as the old system state. + """ + if new_target_tensor is self.available_buffer: + self.available_buffer = self.target_tensor + self.target_tensor = new_target_tensor + + def record_measurement_result(self, key: str, value: Any): + """Adds a measurement result to the log. + + Args: + key: The key the measurement result should be logged under. Note + that operations should only store results under keys they have + declared in a `_measurement_keys_` method. + value: The value to log for the measurement. + """ + if key in self.log_of_measurement_results: + raise ValueError(f"Measurement already logged to key {key!r}") + self.log_of_measurement_results[key] = value + + def subspace_index(self, + little_endian_bits_int: int = 0, + *, + big_endian_bits_int: int = 0 + ) -> Tuple[Union[slice, int, 'ellipsis'], ...]: + """An index for the subspace where the target axes equal a value. + + Args: + little_endian_bits_int: The desired value of the qubits at the + targeted `axes`, packed into an integer. The least significant + bit of the integer is the desired bit for the first axis, and + so forth in increasing order. Can't be specified at the same + time as `big_endian_bits_int`. + big_endian_bits_int: The desired value of the qubits at the + targeted `axes`, packed into an integer. The most significant + bit of the integer is the desired bit for the first axis, and + so forth in decreasing order. Can't be specified at the same + time as `little_endian_bits_int`. + value_tuple: The desired value of the qids at the targeted `axes`, + packed into a tuple. Specify either `little_endian_bits_int` or + `value_tuple`. + + Returns: + A value that can be used to index into `target_tensor` and + `available_buffer`, and manipulate only the part of Hilbert space + corresponding to a given bit assignment. + + Example: + If `target_tensor` is a 4 qubit tensor and `axes` is `[1, 3]` and + then this method will return the following when given + `little_endian_bits=0b01`: + + `(slice(None), 0, slice(None), 1, Ellipsis)` + + Therefore the following two lines would be equivalent: + + args.target_tensor[args.subspace_index(0b01)] += 1 + + args.target_tensor[:, 0, :, 1] += 1 + """ + return linalg.slice_for_qubits_equal_to( + self.axes, + little_endian_qureg_value=little_endian_bits_int, + big_endian_qureg_value=big_endian_bits_int, + qid_shape=self.target_tensor.shape) + + def _act_on_fallback_(self, action: Any, allow_decompose: bool): + strats = [ + _strat_act_on_state_vector_from_apply_unitary, + _strat_act_on_state_vector_from_mixture, + ] + if allow_decompose: + strats.append(_strat_act_on_state_vector_from_apply_decompose) + + # Try each strategy, stopping if one works. + for strat in strats: + result = strat(action, self) + if result is False: + break + if result is True: + return True + assert result is NotImplemented + + return NotImplemented + + +def _strat_act_on_state_vector_from_apply_unitary( + unitary_value: Any, + args: 'cirq.ActOnStateVectorArgs', +) -> bool: + new_target_tensor = protocols.apply_unitary( + unitary_value, + protocols.ApplyUnitaryArgs( + target_tensor=args.target_tensor, + available_buffer=args.available_buffer, + axes=args.axes, + ), + allow_decompose=False, + default=NotImplemented) + if new_target_tensor is NotImplemented: + return NotImplemented + args.swap_target_tensor_for(new_target_tensor) + return True + + +def _strat_act_on_state_vector_from_apply_decompose( + val: Any, + args: ActOnStateVectorArgs, +) -> bool: + operations, qubits, _ = _try_decompose_into_operations_and_qubits(val) + if operations is None: + return NotImplemented + return _act_all_on_state_vector(operations, qubits, args) + + +def _act_all_on_state_vector(actions: Iterable[Any], + qubits: Sequence['cirq.Qid'], + args: 'cirq.ActOnStateVectorArgs'): + if len(qubits) != len(args.axes): + raise ValueError('len(qubits) != len(args.axes)') + qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)} + + old_indices = args.indices + try: + for action in actions: + args.indices = [qubit_map[q] for q in action.qubits] + protocols.act_on(action, args) + finally: + args.indices = old_indices + + +def _strat_act_on_state_vector_from_mixture(action: Any, + args: 'cirq.ActOnStateVectorArgs' + ) -> bool: + mixture = protocols.mixture(action, default=None) + if mixture is None: + return NotImplemented + probabilities, unitaries = zip(*mixture) + + index = args.prng.choice(range(len(unitaries)), p=probabilities) + shape = protocols.qid_shape(action) * 2 + unitary = unitaries[index].astype(args.target_tensor.dtype).reshape(shape) + linalg.targeted_left_multiply(unitary, + args.target_tensor, + args.axes, + out=args.available_buffer) + args.swap_target_tensor_for(args.available_buffer) + return True diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index 00f713e64fc..03907bde000 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -61,7 +61,7 @@ def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None): @staticmethod def is_supported_operation(op: 'cirq.Operation') -> bool: """Checks whether given operation can be simulated by this simulator.""" - if protocols.is_measurement(op): return True + if isinstance(op.gate, cirq.MeasurementGate): return True if isinstance(op, GlobalPhaseOperation): return True if not protocols.has_unitary(op): return False u = cirq.unitary(op) @@ -100,16 +100,20 @@ def _base_iterator(self, circuit: circuits.Circuit, state = CliffordState(qubit_map, initial_state=initial_state) for moment in circuit: - measurements = collections.defaultdict( - list) # type: Dict[str, List[np.ndarray]] + measurements: Dict[str, List[ + np.ndarray]] = collections.defaultdict(list) for op in moment: - if protocols.has_unitary(op): - state.apply_unitary(op) - elif protocols.is_measurement(op): + if protocols.is_measurement(op): + if not isinstance(op.gate, ops.MeasurementGate): + raise NotImplementedError( + 'Measurement type other than cirq.MeasurementGate' + ) key = protocols.measurement_key(op) measurements[key].extend( state.perform_measurement(op.qubits, self._prng)) + elif protocols.has_unitary(op): + state.apply_unitary(op) yield CliffordSimulatorStepResult(measurements=measurements, state=state) diff --git a/cirq/sim/simulator.py b/cirq/sim/simulator.py index 2225051807e..d459e91f70e 100644 --- a/cirq/sim/simulator.py +++ b/cirq/sim/simulator.py @@ -578,8 +578,8 @@ def _qubit_map_to_shape(qubit_map: Dict[ops.Qid, int]) -> Tuple[int, ...]: def _verify_unique_measurement_keys(circuit: circuits.Circuit): result = collections.Counter( - protocols.measurement_key(op, default=None) - for op in ops.flatten_op_tree(iter(circuit))) + key for op in ops.flatten_op_tree(iter(circuit)) + for key in protocols.measurement_keys(op)) result[None] = 0 duplicates = [k for k, v in result.most_common() if v > 1] if duplicates: diff --git a/cirq/sim/sparse_simulator.py b/cirq/sim/sparse_simulator.py index 899a539a523..cb1bcfe3a7f 100644 --- a/cirq/sim/sparse_simulator.py +++ b/cirq/sim/sparse_simulator.py @@ -15,43 +15,26 @@ """A simulator that uses numpy's einsum or sparse matrix operations.""" import collections - -from typing import Dict, Iterator, List, Tuple, Type, TYPE_CHECKING +from typing import Dict, Iterator, List, Type, TYPE_CHECKING import numpy as np -from cirq import circuits, linalg, ops, protocols, qis, study, value -from cirq.sim import simulator, wave_function, wave_function_simulator +from cirq import circuits, ops, protocols, qis, study, value +from cirq.sim import ( + simulator, + wave_function, + wave_function_simulator, + act_on_state_vector_args, +) if TYPE_CHECKING: import cirq -class _FlipGate(ops.SingleQubitGate): - """A unitary gate that flips the |0> state with another state. - - Used by `Simulator` to reset a qubit. - """ - - def __init__(self, dimension: int, reset_value: int): - assert 0 < reset_value < dimension - self.dimension = dimension - self.reset_value = reset_value - - def _qid_shape_(self) -> Tuple[int, ...]: - return (self.dimension,) - - def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray: - args.available_buffer[..., 0] = args.target_tensor[..., self. - reset_value] - args.available_buffer[..., self. - reset_value] = args.target_tensor[..., 0] - return args.available_buffer - - # Mutable named tuple to hold state and a buffer. -class _StateAndBuffer(): - def __init__(self, state, buffer): +class _StateAndBuffer: + + def __init__(self, state: np.ndarray, buffer: np.ndarray): self.state = state self.buffer = buffer @@ -152,11 +135,9 @@ def __init__(self, self._dtype = dtype self._prng = value.parse_random_state(seed) - def _run( - self, - circuit: circuits.Circuit, - param_resolver: study.ParamResolver, - repetitions: int) -> Dict[str, List[np.ndarray]]: + def _run(self, circuit: circuits.Circuit, + param_resolver: study.ParamResolver, + repetitions: int) -> Dict[str, np.ndarray]: """See definition in `cirq.SimulatesSamples`.""" param_resolver = param_resolver or study.ParamResolver({}) resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) @@ -165,13 +146,12 @@ def _run( def measure_or_mixture(op): return protocols.is_measurement(op) or protocols.has_mixture(op) if circuit.are_all_matches_terminal(measure_or_mixture): - return self._run_sweep_sample(resolved_circuit, repetitions) + return self._run_sweep_terminal_sample(resolved_circuit, + repetitions) return self._run_sweep_repeat(resolved_circuit, repetitions) - def _run_sweep_sample( - self, - circuit: circuits.Circuit, - repetitions: int) -> Dict[str, List[np.ndarray]]: + def _run_sweep_terminal_sample(self, circuit: circuits.Circuit, + repetitions: int) -> Dict[str, np.ndarray]: for step_result in self._base_iterator( circuit=circuit, qubit_order=ops.QubitOrder.DEFAULT, @@ -187,16 +167,15 @@ def _run_sweep_sample( repetitions, seed=self._prng) - def _run_sweep_repeat( - self, - circuit: circuits.Circuit, - repetitions: int) -> Dict[str, List[np.ndarray]]: - measurements = {} # type: Dict[str, List[np.ndarray]] + def _run_sweep_repeat(self, circuit: circuits.Circuit, + repetitions: int) -> Dict[str, np.ndarray]: if repetitions == 0: - for _, op, _ in circuit.findall_operations_with_gate_type( - ops.MeasurementGate): - measurements[protocols.measurement_key(op)] = np.empty([0, 1]) + return { + key: np.empty(shape=[0, 1]) + for key in protocols.measurement_keys(circuit) + } + measurements = collections.defaultdict(list) for _ in range(repetitions): all_step_results = self._base_iterator( circuit, @@ -205,8 +184,6 @@ def _run_sweep_repeat( for step_result in all_step_results: for k, v in step_result.measurements.items(): - if not k in measurements: - measurements[k] = [] measurements[k].append(np.array(v, dtype=np.uint8)) return {k: np.array(v) for k, v in measurements.items()} @@ -240,7 +217,7 @@ def _base_iterator( qubit_order: ops.QubitOrderOrList, initial_state: 'cirq.STATE_VECTOR_LIKE', perform_measurements: bool = True, - ) -> Iterator: + ) -> Iterator['SparseSimulatorStep']: qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for( circuit.all_qubits()) num_qubits = len(qubits) @@ -253,120 +230,26 @@ def _base_iterator( if len(circuit) == 0: yield SparseSimulatorStep(state, {}, qubit_map, self._dtype) - def on_stuck(bad_op: ops.Operation): - return TypeError( - "Can't simulate unknown operations that don't specify a " - "_unitary_ method, a _decompose_ method, " - "(_has_unitary_ + _apply_unitary_) methods," - "(_has_mixture_ + _mixture_) methods, or are measurements." - ": {!r}".format(bad_op)) - - def keep(potential_op: ops.Operation) -> bool: - # The order of this is optimized to call has_xxx methods first. - return (protocols.has_unitary(potential_op) or - protocols.has_mixture(potential_op) or - protocols.is_measurement(potential_op) or - isinstance(potential_op.gate, ops.ResetChannel)) - - data = _StateAndBuffer(state=np.reshape(state, qid_shape), - buffer=np.empty(qid_shape, dtype=self._dtype)) + sim_state = act_on_state_vector_args.ActOnStateVectorArgs( + target_tensor=np.reshape(state, qid_shape), + available_buffer=np.empty(qid_shape, dtype=self._dtype), + axes=[], + prng=self._prng, + log_of_measurement_results={}) + for moment in circuit: - measurements = collections.defaultdict( - list) # type: Dict[str, List[int]] - - unitary_ops_and_measurements = protocols.decompose( - moment, keep=keep, on_stuck_raise=on_stuck) - - for op in unitary_ops_and_measurements: - indices = [qubit_map[qubit] for qubit in op.qubits] - if isinstance(op.gate, ops.ResetChannel): - self._simulate_reset(op, data, indices) - elif protocols.has_unitary(op): - self._simulate_unitary(op, data, indices) - elif protocols.is_measurement(op): - # Do measurements second, since there may be mixtures that - # operate as measurements. - # TODO: support measurement outside the computational basis. - # Github issue: - # https://github.com/quantumlib/Cirq/issues/1357 - if perform_measurements: - self._simulate_measurement(op, data, indices, - measurements, num_qubits) - elif protocols.has_mixture(op): - self._simulate_mixture(op, data, indices) - - yield SparseSimulatorStep( - state_vector=data.state, - measurements=measurements, - qubit_map=qubit_map, - dtype=self._dtype) - - def _simulate_unitary(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int]) -> None: - """Simulate an op that has a unitary.""" - result = protocols.apply_unitary( - op, - args=protocols.ApplyUnitaryArgs( - data.state, - data.buffer, - indices)) - if result is data.buffer: - data.buffer = data.state - data.state = result - - def _simulate_reset(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int]) -> None: - """Simulate an op that is a reset to the |0> state.""" - if isinstance(op.gate, ops.ResetChannel): - reset = op.gate - # Do a silent measurement. - bits, _ = wave_function.measure_state_vector( - data.state, indices, out=data.state, qid_shape=data.state.shape) - # Apply bit flip(s) to change the reset the bits to 0. - for b, i, d in zip(bits, indices, protocols.qid_shape(reset)): - if b == 0: - continue # Already zero, no reset needed - reset_unitary = _FlipGate(d, reset_value=b)(*op.qubits) - self._simulate_unitary(reset_unitary, data, [i]) - - def _simulate_measurement(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int], - measurements: Dict[str, List[int]], - num_qubits: int) -> None: - """Simulate an op that is a measurement in the computational basis.""" - # TODO: support measurement outside computational basis. - # Github issue: https://github.com/quantumlib/Cirq/issues/1357 - if isinstance(op.gate, ops.MeasurementGate): - meas = op.gate - invert_mask = meas.full_invert_mask() - # Measure updates inline. - bits, _ = wave_function.measure_state_vector( - data.state, - indices, - out=data.state, - qid_shape=data.state.shape, - seed=self._prng) - corrected = [ - bit ^ (bit < 2 and mask) - for bit, mask in zip(bits, invert_mask) - ] - key = protocols.measurement_key(meas) - measurements[key].extend(corrected) - - def _simulate_mixture(self, op: ops.Operation, data: _StateAndBuffer, - indices: List[int]) -> None: - """Simulate an op that is a mixtures of unitaries.""" - probs, unitaries = zip(*protocols.mixture(op)) - # We work around numpy barfing on choosing from a list of - # numpy arrays (which is not `one-dimensional`) by selecting - # the index of the unitary. - index = self._prng.choice(range(len(unitaries)), p=probs) - shape = protocols.qid_shape(op) * 2 - unitary = unitaries[index].astype(self._dtype).reshape(shape) - result = linalg.targeted_left_multiply(unitary, data.state, indices, - out=data.buffer) - data.buffer = data.state - data.state = result + for op in moment: + if perform_measurements or not isinstance( + op.gate, ops.MeasurementGate): + sim_state.axes = [qubit_map[qubit] for qubit in op.qubits] + protocols.act_on(op, sim_state) + + yield SparseSimulatorStep(state_vector=sim_state.target_tensor, + measurements=dict( + sim_state.log_of_measurement_results), + qubit_map=qubit_map, + dtype=self._dtype) + sim_state.log_of_measurement_results.clear() def _check_all_resolved(self, circuit): """Raises if the circuit contains unresolved symbols.""" diff --git a/cirq/sim/sparse_simulator_test.py b/cirq/sim/sparse_simulator_test.py index 2402734646f..a6bdf91241d 100644 --- a/cirq/sim/sparse_simulator_test.py +++ b/cirq/sim/sparse_simulator_test.py @@ -104,6 +104,13 @@ def test_run_measure_at_end_no_repetitions(dtype): assert mock_sim.call_count == 4 +def test_run_repetitions_terminal_measurement_stochastic(): + q = cirq.LineQubit(0) + c = cirq.Circuit(cirq.H(q), cirq.measure(q, key='q')) + results = cirq.Simulator().run(c, repetitions=10000) + assert 1000 <= sum(v[0] for v in results.measurements['q']) < 9000 + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_run_repetitions_measure_at_end(dtype): q0, q1 = cirq.LineQubit.range(2) diff --git a/docs/api.rst b/docs/api.rst index ed683610a09..30ef1fbdc9b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -308,6 +308,7 @@ the magic methods that can be implemented. :toctree: generated/ cirq.DEFAULT_RESOLVERS + cirq.act_on cirq.apply_channel cirq.apply_mixture cirq.apply_unitaries @@ -355,6 +356,7 @@ the magic methods that can be implemented. cirq.QasmOutput cirq.QuilFormatter cirq.QuilOutput + cirq.SupportsActOn cirq.SupportsApplyChannel cirq.SupportsApplyMixture cirq.SupportsApproximateEquality From 87775f0a2006c8fc9c4d22e424276e1b7a87af81 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Sun, 17 May 2020 19:50:45 -0700 Subject: [PATCH 2/6] Pass changed --- cirq/__init__.py | 3 ++- cirq/protocols/json_serialization_test.py | 8 +++++--- cirq/sim/act_on_state_vector_args.py | 9 +++++---- docs/api.rst | 1 + 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/cirq/__init__.py b/cirq/__init__.py index 879abd66789..220604d45f7 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -440,10 +440,11 @@ QuilFormatter, read_json, resolve_parameters, + SupportsActOn, SupportsApplyChannel, SupportsApplyMixture, - SupportsConsistentApplyUnitary, SupportsApproximateEquality, + SupportsConsistentApplyUnitary, SupportsChannel, SupportsCircuitDiagramInfo, SupportsCommutes, diff --git a/cirq/protocols/json_serialization_test.py b/cirq/protocols/json_serialization_test.py index 2991f8eb4ab..fd540074bc3 100644 --- a/cirq/protocols/json_serialization_test.py +++ b/cirq/protocols/json_serialization_test.py @@ -111,6 +111,11 @@ def test_fail_to_resolve(): # cirq.Circuit(cirq.rx(sympy.Symbol('theta')).on(Q0)), SHOULDNT_BE_SERIALIZED = [ + # Intermediate states with work buffers and unknown external prng guts. + 'ActOnStateVectorArgs', + 'ApplyChannelArgs', + 'ApplyMixtureArgs', + 'ApplyUnitaryArgs', # Circuit optimizers are function-like. Only attributes # are ignore_failures, tolerance, and other feature flags @@ -257,9 +262,6 @@ def test_mutually_exclusive_blacklist(): NOT_YET_SERIALIZABLE = [ - 'ApplyChannelArgs', - 'ApplyMixtureArgs', - 'ApplyUnitaryArgs', 'AsymmetricDepolarizingChannel', 'AxisAngleDecomposition', 'Calibration', diff --git a/cirq/sim/act_on_state_vector_args.py b/cirq/sim/act_on_state_vector_args.py index c34e4cac7d1..0c1ad779252 100644 --- a/cirq/sim/act_on_state_vector_args.py +++ b/cirq/sim/act_on_state_vector_args.py @@ -143,7 +143,7 @@ def _act_on_fallback_(self, action: Any, allow_decompose: bool): break if result is True: return True - assert result is NotImplemented + assert result is NotImplemented, str(result) return NotImplemented @@ -184,13 +184,14 @@ def _act_all_on_state_vector(actions: Iterable[Any], raise ValueError('len(qubits) != len(args.axes)') qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)} - old_indices = args.indices + old_axes = args.axes try: for action in actions: - args.indices = [qubit_map[q] for q in action.qubits] + args.axes = [qubit_map[q] for q in action.qubits] protocols.act_on(action, args) finally: - args.indices = old_indices + args.axes = old_axes + return True def _strat_act_on_state_vector_from_mixture(action: Any, diff --git a/docs/api.rst b/docs/api.rst index 30ef1fbdc9b..17c49b73ed1 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -240,6 +240,7 @@ results. cirq.validate_mixture cirq.validate_probability cirq.xeb_fidelity + cirq.ActOnStateVectorArgs cirq.CircuitSampleJob cirq.CliffordSimulator cirq.CliffordSimulatorStepResult From 175b6740134d78895f657663c59854737d3a80b3 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Sun, 17 May 2020 19:55:38 -0700 Subject: [PATCH 3/6] typecheck, lint --- cirq/ops/common_channels.py | 2 +- cirq/ops/gate_operation.py | 2 +- cirq/qis/states.py | 2 +- cirq/sim/act_on_state_vector_args.py | 5 ++--- cirq/sim/clifford/clifford_simulator.py | 3 +-- cirq/sim/simulator.py | 10 +++++----- cirq/sim/sparse_simulator.py | 8 +++++--- cirq/sim/wave_function.py | 12 +++--------- 8 files changed, 19 insertions(+), 25 deletions(-) diff --git a/cirq/ops/common_channels.py b/cirq/ops/common_channels.py index 024adec42ec..987bd433cce 100644 --- a/cirq/ops/common_channels.py +++ b/cirq/ops/common_channels.py @@ -19,7 +19,7 @@ import numpy as np -from cirq import protocols, value, linalg +from cirq import protocols, value from cirq.ops import (raw_types, common_gates, pauli_gates, gate_features, identity) diff --git a/cirq/ops/gate_operation.py b/cirq/ops/gate_operation.py index ecdda3632ef..a3f8f733419 100644 --- a/cirq/ops/gate_operation.py +++ b/cirq/ops/gate_operation.py @@ -16,7 +16,7 @@ import re from typing import (Any, Dict, FrozenSet, Iterable, List, Optional, Sequence, - Tuple, TypeVar, Union, TYPE_CHECKING, Type) + Tuple, TypeVar, Union, TYPE_CHECKING) import numpy as np diff --git a/cirq/qis/states.py b/cirq/qis/states.py index 88319bb9742..f343492b801 100644 --- a/cirq/qis/states.py +++ b/cirq/qis/states.py @@ -392,7 +392,7 @@ def validate_qid_shape(state: np.ndarray, return qid_shape -def validate_indices(num_qubits: int, indices: List[int]) -> None: +def validate_indices(num_qubits: int, indices: Sequence[int]) -> None: """Validates that the indices have values within range of num_qubits.""" if any(index < 0 for index in indices): raise IndexError('Negative index in indices: {}'.format(indices)) diff --git a/cirq/sim/act_on_state_vector_args.py b/cirq/sim/act_on_state_vector_args.py index 0c1ad779252..b57429d8d62 100644 --- a/cirq/sim/act_on_state_vector_args.py +++ b/cirq/sim/act_on_state_vector_args.py @@ -13,8 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a state vector.""" -from typing import (Any, Iterable, Sequence, Tuple, TYPE_CHECKING, Union, - DefaultDict, List, Dict) +from typing import Any, Iterable, Sequence, Tuple, TYPE_CHECKING, Union, Dict import numpy as np @@ -187,7 +186,7 @@ def _act_all_on_state_vector(actions: Iterable[Any], old_axes = args.axes try: for action in actions: - args.axes = [qubit_map[q] for q in action.qubits] + args.axes = tuple(qubit_map[q] for q in action.qubits) protocols.act_on(action, args) finally: args.axes = old_axes diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index 03907bde000..d62656bb13f 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -107,8 +107,7 @@ def _base_iterator(self, circuit: circuits.Circuit, if protocols.is_measurement(op): if not isinstance(op.gate, ops.MeasurementGate): raise NotImplementedError( - 'Measurement type other than cirq.MeasurementGate' - ) + f'Unrecognized measurement type {op!r}') key = protocols.measurement_key(op) measurements[key].extend( state.perform_measurement(op.qubits, self._prng)) diff --git a/cirq/sim/simulator.py b/cirq/sim/simulator.py index d459e91f70e..c56e608fcbc 100644 --- a/cirq/sim/simulator.py +++ b/cirq/sim/simulator.py @@ -580,8 +580,8 @@ def _verify_unique_measurement_keys(circuit: circuits.Circuit): result = collections.Counter( key for op in ops.flatten_op_tree(iter(circuit)) for key in protocols.measurement_keys(op)) - result[None] = 0 - duplicates = [k for k, v in result.most_common() if v > 1] - if duplicates: - raise ValueError('Measurement key {} repeated'.format( - ",".join(duplicates))) + if result: + duplicates = [k for k, v in result.most_common() if v > 1] + if duplicates: + raise ValueError('Measurement key {} repeated'.format( + ",".join(duplicates))) diff --git a/cirq/sim/sparse_simulator.py b/cirq/sim/sparse_simulator.py index cb1bcfe3a7f..6996c4461e0 100644 --- a/cirq/sim/sparse_simulator.py +++ b/cirq/sim/sparse_simulator.py @@ -15,7 +15,7 @@ """A simulator that uses numpy's einsum or sparse matrix operations.""" import collections -from typing import Dict, Iterator, List, Type, TYPE_CHECKING +from typing import Dict, Iterator, List, Type, TYPE_CHECKING, DefaultDict import numpy as np @@ -175,7 +175,8 @@ def _run_sweep_repeat(self, circuit: circuits.Circuit, for key in protocols.measurement_keys(circuit) } - measurements = collections.defaultdict(list) + measurements: DefaultDict[str, List[ + np.ndarray]] = collections.defaultdict(list) for _ in range(repetitions): all_step_results = self._base_iterator( circuit, @@ -241,7 +242,8 @@ def _base_iterator( for op in moment: if perform_measurements or not isinstance( op.gate, ops.MeasurementGate): - sim_state.axes = [qubit_map[qubit] for qubit in op.qubits] + sim_state.axes = tuple( + qubit_map[qubit] for qubit in op.qubits) protocols.act_on(op, sim_state) yield SparseSimulatorStep(state_vector=sim_state.target_tensor, diff --git a/cirq/sim/wave_function.py b/cirq/sim/wave_function.py index b4c2e590b6a..ec0b7dbb065 100644 --- a/cirq/sim/wave_function.py +++ b/cirq/sim/wave_function.py @@ -13,13 +13,7 @@ # limitations under the License. """Helpers for handling quantum wavefunctions.""" -from typing import ( - Dict, - List, - Optional, - Tuple, - TYPE_CHECKING, -) +from typing import (Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence) import abc import numpy as np @@ -248,7 +242,7 @@ def sample_state_vector( def measure_state_vector( state: np.ndarray, - indices: List[int], + indices: Sequence[int], *, # Force keyword args qid_shape: Optional[Tuple[int, ...]] = None, out: np.ndarray = None, @@ -337,7 +331,7 @@ def measure_state_vector( return measurement_bits, out -def _probs(state: np.ndarray, indices: List[int], +def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: """Returns the probabilities for a measurement on the given indices.""" tensor = np.reshape(state, qid_shape) From 0c84c9d1020c7ab8add840a99af46357537a78bf Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 19 May 2020 17:18:29 -0700 Subject: [PATCH 4/6] comments --- cirq/ops/measurement_gate_test.py | 80 ++++++++++++++++++++++++ cirq/protocols/act_on_protocol.py | 2 +- cirq/protocols/apply_unitary_protocol.py | 3 - cirq/sim/act_on_state_vector_args.py | 32 +++++++--- cirq/sim/clifford/clifford_simulator.py | 1 + 5 files changed, 107 insertions(+), 11 deletions(-) diff --git a/cirq/ops/measurement_gate_test.py b/cirq/ops/measurement_gate_test.py index 3f19c0a7534..22ece61b1ba 100644 --- a/cirq/ops/measurement_gate_test.py +++ b/cirq/ops/measurement_gate_test.py @@ -189,3 +189,83 @@ def test_op_repr(): "cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), " "key='out', " "invert_mask=(False, True))") + + +def test_act_on(): + a, b = cirq.LineQubit.range(2) + m = cirq.measure(a, b, key='out', invert_mask=(True,)) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(shape=(2, 2, 2, 2, 2), dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [1, 0]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 1, 0, 0, 0), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [1, 1]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 1, 0, 1, 0), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 1]} + + +def test_act_on_qutrit(): + a, b = cirq.LineQid.range(2, dimension=3) + m = cirq.measure(a, b, key='out', invert_mask=(True,)) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 2, 0, 2, 0), + shape=(3, 3, 3, 3, 3), + dtype=np.complex64), + available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [2, 2]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 1, 0, 2, 0), + shape=(3, 3, 3, 3, 3), + dtype=np.complex64), + available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [2, 1]} + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(0, 2, 0, 1, 0), + shape=(3, 3, 3, 3, 3), + dtype=np.complex64), + available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), + axes=[3, 1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + cirq.act_on(m, args) + assert args.log_of_measurement_results == {'out': [0, 2]} diff --git a/cirq/protocols/act_on_protocol.py b/cirq/protocols/act_on_protocol.py index 5b61fdcb24e..474494f7e0c 100644 --- a/cirq/protocols/act_on_protocol.py +++ b/cirq/protocols/act_on_protocol.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Cirq Developers +# Copyright 2020 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/cirq/protocols/apply_unitary_protocol.py b/cirq/protocols/apply_unitary_protocol.py index 9513e6f7071..0ed9addfbe2 100644 --- a/cirq/protocols/apply_unitary_protocol.py +++ b/cirq/protocols/apply_unitary_protocol.py @@ -190,9 +190,6 @@ def subspace_index(self, bit of the integer is the desired bit for the first axis, and so forth in decreasing order. Can't be specified at the same time as `little_endian_bits_int`. - value_tuple: The desired value of the qids at the targeted `axes`, - packed into a tuple. Specify either `little_endian_bits_int` or - `value_tuple`. Returns: A value that can be used to index into `target_tensor` and diff --git a/cirq/sim/act_on_state_vector_args.py b/cirq/sim/act_on_state_vector_args.py index b57429d8d62..e23461d56ea 100644 --- a/cirq/sim/act_on_state_vector_args.py +++ b/cirq/sim/act_on_state_vector_args.py @@ -26,7 +26,16 @@ class ActOnStateVectorArgs: - """State and context for an operation acting on a state vector.""" + """State and context for an operation acting on a state vector. + + There are three common ways to act on this object: + + 1. Directly edit the `target_tensor` property, which is storing the state + vector of the quantum system as a numpy array with one axis per qudit. + 2. Overwrite the `available_buffer` property with the new state vector, and + then pass `available_buffer` into `swap_target_tensor_for`. + 3. Call `record_measurement_result(key, val)` to log a measurement result. + """ def __init__(self, target_tensor: np.ndarray, available_buffer: np.ndarray, axes: Iterable[int], prng: np.random.RandomState, @@ -37,9 +46,10 @@ def __init__(self, target_tensor: np.ndarray, available_buffer: np.ndarray, with one dimension for each qubit in the system. Operations are expected to perform inplace edits of this object. available_buffer: A workspace with the same shape and dtype as - `target_tensor`. The result of an operation can be put into this - buffer, instead of directly editing `target_tensor`, if - `swap_target_tensor_for` is called afterward. + `target_tensor`. Used by operations that cannot be applied to + `target_tensor` inline, in order to avoid unnecessary + allocations. Passing `available_buffer` into + `swap_target_tensor_for` will swap it for `target_tensor`. axes: The indices of axes corresponding to the qubits that the operation is supposed to act upon. prng: The pseudo random number generator to use for probabilistic @@ -94,14 +104,22 @@ def subspace_index(self, bit of the integer is the desired bit for the first axis, and so forth in increasing order. Can't be specified at the same time as `big_endian_bits_int`. + + When operating on qudits instead of qubits, the same basic logic + applies but in a different basis. For example, if the target + axes have dimension [a:2, b:3, c:2] then the integer 10 + decomposes into [a=0, b=2, c=1] via 7 = 1*(3*2) + 2*(2) + 0. + big_endian_bits_int: The desired value of the qubits at the targeted `axes`, packed into an integer. The most significant bit of the integer is the desired bit for the first axis, and so forth in decreasing order. Can't be specified at the same time as `little_endian_bits_int`. - value_tuple: The desired value of the qids at the targeted `axes`, - packed into a tuple. Specify either `little_endian_bits_int` or - `value_tuple`. + + When operating on qudits instead of qubits, the same basic logic + applies but in a different basis. For example, if the target + axes have dimension [a:2, b:3, c:2] then the integer 10 + decomposes into [a=1, b=2, c=0] via 7 = 1*(3*2) + 2*(2) + 0. Returns: A value that can be used to index into `target_tensor` and diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index d62656bb13f..457af3df6bb 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -61,6 +61,7 @@ def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None): @staticmethod def is_supported_operation(op: 'cirq.Operation') -> bool: """Checks whether given operation can be simulated by this simulator.""" + # TODO: support more general Pauli measurements if isinstance(op.gate, cirq.MeasurementGate): return True if isinstance(op, GlobalPhaseOperation): return True if not protocols.has_unitary(op): return False From c20031b28a07adcd369cb4942905035acf154eb3 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 19 May 2020 17:38:03 -0700 Subject: [PATCH 5/6] Cover --- cirq/ops/common_channels_test.py | 29 ++++++++++++ cirq/ops/measurement_gate_test.py | 6 +++ cirq/protocols/act_on_protocol.py | 2 +- cirq/protocols/act_on_protocol_test.py | 33 ++++++++++++++ cirq/sim/act_on_state_vector_args.py | 5 +-- cirq/sim/act_on_state_vector_args_test.py | 54 +++++++++++++++++++++++ cirq/sim/clifford/clifford_simulator.py | 40 ++++++++--------- 7 files changed, 145 insertions(+), 24 deletions(-) create mode 100644 cirq/protocols/act_on_protocol_test.py create mode 100644 cirq/sim/act_on_state_vector_args_test.py diff --git a/cirq/ops/common_channels_test.py b/cirq/ops/common_channels_test.py index d5d01984b4f..41a06275b1a 100644 --- a/cirq/ops/common_channels_test.py +++ b/cirq/ops/common_channels_test.py @@ -314,6 +314,35 @@ def test_reset_channel_text_diagram(): cirq.ResetChannel(3)) == cirq.CircuitDiagramInfo(wire_symbols=('R',))) +def test_reset_act_on(): + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(cirq.ResetChannel(), object()) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(index=(1, 1, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}, + ) + + cirq.act_on(cirq.ResetChannel(), args) + assert args.log_of_measurement_results == {} + np.testing.assert_allclose(args.target_tensor, + cirq.one_hot(index=(1, 0, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64)) + + cirq.act_on(cirq.ResetChannel(), args) + assert args.log_of_measurement_results == {} + np.testing.assert_allclose(args.target_tensor, + cirq.one_hot(index=(1, 0, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64)) + + def test_phase_damping_channel(): d = cirq.phase_damp(0.3) np.testing.assert_almost_equal(cirq.channel(d), diff --git a/cirq/ops/measurement_gate_test.py b/cirq/ops/measurement_gate_test.py index 22ece61b1ba..f8797d61f9f 100644 --- a/cirq/ops/measurement_gate_test.py +++ b/cirq/ops/measurement_gate_test.py @@ -195,6 +195,9 @@ def test_act_on(): a, b = cirq.LineQubit.range(2) m = cirq.measure(a, b, key='out', invert_mask=(True,)) + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(m, object()) + args = cirq.ActOnStateVectorArgs( target_tensor=cirq.one_hot(shape=(2, 2, 2, 2, 2), dtype=np.complex64), available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), @@ -229,6 +232,9 @@ def test_act_on(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 1]} + with pytest.raises(ValueError, match="already logged to key"): + cirq.act_on(m, args) + def test_act_on_qutrit(): a, b = cirq.LineQid.range(2, dimension=3) diff --git a/cirq/protocols/act_on_protocol.py b/cirq/protocols/act_on_protocol.py index 474494f7e0c..71fec022cf1 100644 --- a/cirq/protocols/act_on_protocol.py +++ b/cirq/protocols/act_on_protocol.py @@ -103,7 +103,7 @@ def act_on( f'_act_on_ must return True or NotImplemented but got ' f'{result!r} from {action!r}._act_on_') - arg_fallback = getattr(args, '_act_on_fallback_') + arg_fallback = getattr(args, '_act_on_fallback_', None) if arg_fallback is not None: result = arg_fallback(action, allow_decompose=allow_decompose) if result is True: diff --git a/cirq/protocols/act_on_protocol_test.py b/cirq/protocols/act_on_protocol_test.py new file mode 100644 index 00000000000..073846cc30a --- /dev/null +++ b/cirq/protocols/act_on_protocol_test.py @@ -0,0 +1,33 @@ +# Copyright 2020 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A protocol that wouldn't exist if python had __rimul__.""" + +import pytest + +import cirq + + +def test_act_on_checks(): + class Bad(): + def _act_on_(self, args): + return False + + def _act_on_fallback_(self, action, allow_decompose): + return False + + with pytest.raises(ValueError, match="must return True or NotImplemented"): + _ = cirq.act_on(Bad(), object()) + + with pytest.raises(ValueError, match="must return True or NotImplemented"): + _ = cirq.act_on(object(), Bad()) diff --git a/cirq/sim/act_on_state_vector_args.py b/cirq/sim/act_on_state_vector_args.py index e23461d56ea..94ee0d773ff 100644 --- a/cirq/sim/act_on_state_vector_args.py +++ b/cirq/sim/act_on_state_vector_args.py @@ -157,7 +157,7 @@ def _act_on_fallback_(self, action: Any, allow_decompose: bool): for strat in strats: result = strat(action, self) if result is False: - break + break # coverage: ignore if result is True: return True assert result is NotImplemented, str(result) @@ -197,8 +197,7 @@ def _strat_act_on_state_vector_from_apply_decompose( def _act_all_on_state_vector(actions: Iterable[Any], qubits: Sequence['cirq.Qid'], args: 'cirq.ActOnStateVectorArgs'): - if len(qubits) != len(args.axes): - raise ValueError('len(qubits) != len(args.axes)') + assert len(qubits) == len(args.axes) qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)} old_axes = args.axes diff --git a/cirq/sim/act_on_state_vector_args_test.py b/cirq/sim/act_on_state_vector_args_test.py new file mode 100644 index 00000000000..132ee5d69b6 --- /dev/null +++ b/cirq/sim/act_on_state_vector_args_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import cirq + + +def test_decomposed_fallback(): + class Composite(cirq.Gate): + def num_qubits(self) -> int: + return 1 + + def _decompose_(self, qubits): + yield cirq.X(*qubits) + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64), + available_buffer=np.empty((2, 2, 2), dtype=np.complex64), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}) + + cirq.act_on(Composite(), args) + np.testing.assert_allclose( + args.target_tensor, + cirq.one_hot(index=(0, 1, 0), shape=(2, 2, 2), dtype=np.complex64)) + + +def test_cannot_act(): + class NoDetails: + pass + + args = cirq.ActOnStateVectorArgs( + target_tensor=cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64), + available_buffer=np.empty((2, 2, 2), dtype=np.complex64), + axes=[1], + prng=np.random.RandomState(), + log_of_measurement_results={}) + + with pytest.raises(TypeError, match="Failed to act"): + cirq.act_on(NoDetails(), args) diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index 457af3df6bb..7fbe4ec9369 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -97,26 +97,26 @@ def _base_iterator(self, circuit: circuits.Circuit, state=CliffordState( qubit_map, initial_state=initial_state)) - else: - state = CliffordState(qubit_map, initial_state=initial_state) - - for moment in circuit: - measurements: Dict[str, List[ - np.ndarray]] = collections.defaultdict(list) - - for op in moment: - if protocols.is_measurement(op): - if not isinstance(op.gate, ops.MeasurementGate): - raise NotImplementedError( - f'Unrecognized measurement type {op!r}') - key = protocols.measurement_key(op) - measurements[key].extend( - state.perform_measurement(op.qubits, self._prng)) - elif protocols.has_unitary(op): - state.apply_unitary(op) - - yield CliffordSimulatorStepResult(measurements=measurements, - state=state) + return + + state = CliffordState(qubit_map, initial_state=initial_state) + + for moment in circuit: + measurements: Dict[str, List[ + np.ndarray]] = collections.defaultdict(list) + + for op in moment: + if isinstance(op.gate, ops.MeasurementGate): + key = protocols.measurement_key(op) + measurements[key].extend( + state.perform_measurement(op.qubits, self._prng)) + elif protocols.has_unitary(op): + state.apply_unitary(op) + else: + raise NotImplementedError(f"Unrecognized operation: {op!r}") + + yield CliffordSimulatorStepResult(measurements=measurements, + state=state) def _simulator_iterator( self, From 8600f055f0e1413df0f7785249a3c000c5e87bde Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 19 May 2020 17:53:06 -0700 Subject: [PATCH 6/6] lint + format --- cirq/ops/common_channels_test.py | 18 ++++++++++-------- cirq/ops/measurement_gate.py | 3 +-- cirq/protocols/act_on_protocol_test.py | 2 ++ cirq/sim/act_on_state_vector_args_test.py | 3 +++ cirq/sim/clifford/clifford_simulator.py | 4 ++-- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/cirq/ops/common_channels_test.py b/cirq/ops/common_channels_test.py index 41a06275b1a..1e2b647cb10 100644 --- a/cirq/ops/common_channels_test.py +++ b/cirq/ops/common_channels_test.py @@ -330,17 +330,19 @@ def test_reset_act_on(): cirq.act_on(cirq.ResetChannel(), args) assert args.log_of_measurement_results == {} - np.testing.assert_allclose(args.target_tensor, - cirq.one_hot(index=(1, 0, 1, 1, 1), - shape=(2, 2, 2, 2, 2), - dtype=np.complex64)) + np.testing.assert_allclose( + args.target_tensor, + cirq.one_hot(index=(1, 0, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64)) cirq.act_on(cirq.ResetChannel(), args) assert args.log_of_measurement_results == {} - np.testing.assert_allclose(args.target_tensor, - cirq.one_hot(index=(1, 0, 1, 1, 1), - shape=(2, 2, 2, 2, 2), - dtype=np.complex64)) + np.testing.assert_allclose( + args.target_tensor, + cirq.one_hot(index=(1, 0, 1, 1, 1), + shape=(2, 2, 2, 2, 2), + dtype=np.complex64)) def test_phase_damping_channel(): diff --git a/cirq/ops/measurement_gate.py b/cirq/ops/measurement_gate.py index 1b1248abce4..4044dc2bb4c 100644 --- a/cirq/ops/measurement_gate.py +++ b/cirq/ops/measurement_gate.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, \ - TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING import numpy as np diff --git a/cirq/protocols/act_on_protocol_test.py b/cirq/protocols/act_on_protocol_test.py index 073846cc30a..2b7315e2f67 100644 --- a/cirq/protocols/act_on_protocol_test.py +++ b/cirq/protocols/act_on_protocol_test.py @@ -19,7 +19,9 @@ def test_act_on_checks(): + class Bad(): + def _act_on_(self, args): return False diff --git a/cirq/sim/act_on_state_vector_args_test.py b/cirq/sim/act_on_state_vector_args_test.py index 132ee5d69b6..878dcc7858d 100644 --- a/cirq/sim/act_on_state_vector_args_test.py +++ b/cirq/sim/act_on_state_vector_args_test.py @@ -19,7 +19,9 @@ def test_decomposed_fallback(): + class Composite(cirq.Gate): + def num_qubits(self) -> int: return 1 @@ -40,6 +42,7 @@ def _decompose_(self, qubits): def test_cannot_act(): + class NoDetails: pass diff --git a/cirq/sim/clifford/clifford_simulator.py b/cirq/sim/clifford/clifford_simulator.py index 7fbe4ec9369..4a43f5786aa 100644 --- a/cirq/sim/clifford/clifford_simulator.py +++ b/cirq/sim/clifford/clifford_simulator.py @@ -102,8 +102,8 @@ def _base_iterator(self, circuit: circuits.Circuit, state = CliffordState(qubit_map, initial_state=initial_state) for moment in circuit: - measurements: Dict[str, List[ - np.ndarray]] = collections.defaultdict(list) + measurements: Dict[str, List[np.ndarray]] = collections.defaultdict( + list) for op in moment: if isinstance(op.gate, ops.MeasurementGate):