Skip to content

Move decompose-and-act-on strategy to act_on file and reuse it in all ActOn*Args #3874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from cirq.protocols.act_on_protocol import (
act_on,
strat_act_on_from_apply_decompose,
SupportsActOn,
)
from cirq.protocols.apply_unitary_protocol import (
Expand Down
41 changes: 41 additions & 0 deletions cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

from cirq._doc import doc_private
from cirq.type_workarounds import NotImplementedType
from cirq.protocols.decompose_protocol import (
_try_decompose_into_operations_and_qubits,
)

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -123,3 +126,41 @@ def act_on(
f"Action type: {type(action)}\n"
f"Action repr: {action!r}\n"
)


def strat_act_on_from_apply_decompose(
action: Any,
args: Any,
) -> bool:
"""Decomposes the input action and then tries to apply `act_on` to each constituent operation.

Args:
action: The action to decompose and apply to the arg. Typically a `cirq.Operation`.
args: A mutable state object that should be modified by the action. If the internal state
is mutated and the method eventually fails, the internal state is restored to its original
value.

Returns:
True if the action was able to fully act on the args. Returns NotImplemented otherwise.
"""
operations, qubits, _ = _try_decompose_into_operations_and_qubits(action)

if operations is None:
return NotImplemented
assert len(qubits) == len(args.axes)
qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)}

old_axes = args.axes
old_internal_state = args.internal_state.copy()
try:
for action in operations:
args.axes = tuple(qubit_map[q] for q in action.qubits)
act_on(action, args)
except TypeError:
# Restore original state in case of failure since the try block above might have modified
# it, but the modification is now useless since the whole loop was unsuccessful.
args.internal_state = old_internal_state
return NotImplemented
finally:
args.axes = old_axes
return True
41 changes: 10 additions & 31 deletions cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a state vector."""

from typing import Any, Iterable, Sequence, Tuple, TYPE_CHECKING, Union, Dict
from typing import Any, Iterable, Tuple, TYPE_CHECKING, Union, Dict

import numpy as np

from cirq import linalg, protocols
from cirq.protocols.decompose_protocol import (
_try_decompose_into_operations_and_qubits,
)

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -84,6 +81,14 @@ def swap_target_tensor_for(self, new_target_tensor: np.ndarray):
self.available_buffer = self.target_tensor
self.target_tensor = new_target_tensor

@property
def internal_state(self):
return self.target_tensor

@internal_state.setter
def internal_state(self, new_value):
self.swap_target_tensor_for(new_value)

def record_measurement_result(self, key: str, value: Any):
"""Adds a measurement result to the log.

Expand Down Expand Up @@ -157,7 +162,7 @@ def _act_on_fallback_(self, action: Any, allow_decompose: bool):
_strat_act_on_state_vector_from_channel,
]
if allow_decompose:
strats.append(_strat_act_on_state_vector_from_apply_decompose)
strats.append(protocols.strat_act_on_from_apply_decompose) # type: ignore

# Try each strategy, stopping if one works.
for strat in strats:
Expand Down Expand Up @@ -191,32 +196,6 @@ def _strat_act_on_state_vector_from_apply_unitary(
return True


def _strat_act_on_state_vector_from_apply_decompose(
val: Any,
args: ActOnStateVectorArgs,
) -> bool:
operations, qubits, _ = _try_decompose_into_operations_and_qubits(val)
if operations is None:
return NotImplemented
return _act_all_on_state_vector(operations, qubits, args)


def _act_all_on_state_vector(
actions: Iterable[Any], qubits: Sequence['cirq.Qid'], args: 'cirq.ActOnStateVectorArgs'
):
assert len(qubits) == len(args.axes)
qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)}

old_axes = args.axes
try:
for action in actions:
args.axes = tuple(qubit_map[q] for q in action.qubits)
protocols.act_on(action, args)
finally:
args.axes = old_axes
return True


def _strat_act_on_state_vector_from_mixture(action: Any, args: 'cirq.ActOnStateVectorArgs') -> bool:
mixture = protocols.mixture(action, default=None)
if mixture is None:
Expand Down
26 changes: 26 additions & 0 deletions cirq/sim/act_on_state_vector_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,32 @@ def _decompose_(self, qubits):
)


def test_unsupported_decomposed_preserves_args():
class NoDetails(cirq.SingleQubitGate):
pass

class UnsupportedComposite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
return [cirq.H(*qubits), NoDetails().on(*qubits), cirq.H(*qubits)]

original_tensor = cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64)

args = cirq.ActOnStateVectorArgs(
target_tensor=original_tensor.copy(),
available_buffer=np.empty((2, 2, 2), dtype=np.complex64),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
with pytest.raises(TypeError, match="Failed to act"):
cirq.act_on(UnsupportedComposite(), args)

np.testing.assert_allclose(args.target_tensor, original_tensor)


def test_cannot_act():
class NoDetails:
pass
Expand Down
19 changes: 14 additions & 5 deletions cirq/sim/clifford/act_on_clifford_tableau_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cirq.ops import common_gates
from cirq.ops import pauli_gates
from cirq.ops.clifford_gate import SingleQubitCliffordGate
from cirq.protocols import has_unitary, num_qubits, unitary
from cirq import protocols
from cirq.sim.clifford.clifford_tableau import CliffordTableau

if TYPE_CHECKING:
Expand Down Expand Up @@ -60,6 +60,14 @@ def __init__(
self.prng = prng
self.log_of_measurement_results = log_of_measurement_results

@property
def internal_state(self):
return self.tableau

@internal_state.setter
def internal_state(self, new_value):
self.tableau = new_value

def record_measurement_result(self, key: str, value: Any):
"""Adds a measurement result to the log.
Args:
Expand All @@ -75,7 +83,8 @@ def record_measurement_result(self, key: str, value: Any):
def _act_on_fallback_(self, action: Any, allow_decompose: bool):
strats = []
if allow_decompose:
strats.append(_strat_act_on_clifford_tableau_from_single_qubit_decompose)
strats.append(protocols.strat_act_on_from_apply_decompose)
strats.append(_strat_act_on_clifford_tableau_from_single_qubit_decompose) # type: ignore
for strat in strats:
result = strat(action, self)
if result is False:
Expand All @@ -90,10 +99,10 @@ def _act_on_fallback_(self, action: Any, allow_decompose: bool):
def _strat_act_on_clifford_tableau_from_single_qubit_decompose(
val: Any, args: 'cirq.ActOnCliffordTableauArgs'
) -> bool:
if num_qubits(val) == 1:
if not has_unitary(val):
if protocols.num_qubits(val) == 1:
if not protocols.has_unitary(val):
return NotImplemented
u = unitary(val)
u = protocols.unitary(val)
clifford_gate = SingleQubitCliffordGate.from_unitary(u)
if clifford_gate is not None:
for axis, quarter_turns in clifford_gate.decompose_rotation():
Expand Down
52 changes: 52 additions & 0 deletions cirq/sim/clifford/act_on_clifford_tableau_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,58 @@ def _unitary_(self):
assert args.tableau == expected_args.tableau


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)

original_tableau = cirq.CliffordTableau(num_qubits=3)

args = cirq.ActOnCliffordTableauArgs(
tableau=original_tableau.copy(),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
cirq.act_on(Composite(), args)
expected_args = cirq.ActOnCliffordTableauArgs(
tableau=original_tableau.copy(),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
cirq.act_on(cirq.X, expected_args)
assert args.tableau == expected_args.tableau


def test_unsupported_decomposed_preserves_args():
class NoDetails(cirq.SingleQubitGate):
pass

class UnsupportedComposite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
return [cirq.H(*qubits), NoDetails().on(*qubits), cirq.H(*qubits)]

original_tableau = cirq.CliffordTableau(num_qubits=3)

args = cirq.ActOnCliffordTableauArgs(
tableau=original_tableau.copy(),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
with pytest.raises(TypeError, match="Failed to act"):
cirq.act_on(UnsupportedComposite(), args)

assert args.tableau == original_tableau


def test_cannot_act():
class NoDetails:
pass
Expand Down
21 changes: 15 additions & 6 deletions cirq/sim/clifford/act_on_stabilizer_ch_form_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from cirq.ops import common_gates, pauli_gates
from cirq.ops.clifford_gate import SingleQubitCliffordGate
from cirq.protocols import has_unitary, num_qubits, unitary
from cirq import protocols
from cirq.sim.clifford.stabilizer_state_ch_form import StabilizerStateChForm

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,6 +57,14 @@ def __init__(
self.prng = prng
self.log_of_measurement_results = log_of_measurement_results

@property
def internal_state(self):
return self.state

@internal_state.setter
def internal_state(self, new_value):
self.state = new_value

def record_measurement_result(self, key: str, value: Any):
"""Adds a measurement result to the log.
Args:
Expand All @@ -72,7 +80,8 @@ def record_measurement_result(self, key: str, value: Any):
def _act_on_fallback_(self, action: Any, allow_decompose: bool):
strats = []
if allow_decompose:
strats.append(_strat_act_on_stabilizer_ch_form_from_single_qubit_decompose)
strats.append(protocols.strat_act_on_from_apply_decompose)
strats.append(_strat_act_on_stabilizer_ch_form_from_single_qubit_decompose) # type: ignore
for strat in strats:
result = strat(action, self)
if result is True:
Expand All @@ -85,10 +94,10 @@ def _act_on_fallback_(self, action: Any, allow_decompose: bool):
def _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose(
val: Any, args: 'cirq.ActOnStabilizerCHFormArgs'
) -> bool:
if num_qubits(val) == 1:
if not has_unitary(val):
if protocols.num_qubits(val) == 1:
if not protocols.has_unitary(val):
return NotImplemented
u = unitary(val)
u = protocols.unitary(val)
clifford_gate = SingleQubitCliffordGate.from_unitary(u)
if clifford_gate is not None:
# Gather the effective unitary applied so as to correct for the
Expand All @@ -107,7 +116,7 @@ def _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose(
gate = common_gates.ZPowGate(exponent=quarter_turns / 2)
assert gate._act_on_(args)

final_unitary = np.matmul(unitary(gate), final_unitary)
final_unitary = np.matmul(protocols.unitary(gate), final_unitary)

# Find the entry with the largest magnitude in the input unitary.
k = max(np.ndindex(*u.shape), key=lambda t: abs(u[t]))
Expand Down
52 changes: 52 additions & 0 deletions cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,55 @@ def _unitary_(self):
)
cirq.act_on(cirq.H, expected_args)
np.testing.assert_allclose(args.state.state_vector(), expected_args.state.state_vector())


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)

original_state = cirq.StabilizerStateChForm(num_qubits=3)

args = cirq.ActOnStabilizerCHFormArgs(
state=original_state.copy(),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
cirq.act_on(Composite(), args)
expected_args = cirq.ActOnStabilizerCHFormArgs(
state=original_state.copy(),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
cirq.act_on(cirq.X, expected_args)
np.testing.assert_allclose(args.state.state_vector(), expected_args.state.state_vector())


def test_unsupported_decomposed_preserves_args():
class NoDetails(cirq.SingleQubitGate):
pass

class UnsupportedComposite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
return [cirq.H(*qubits), NoDetails().on(*qubits), cirq.H(*qubits)]

original_state = cirq.StabilizerStateChForm(num_qubits=3)

args = cirq.ActOnStabilizerCHFormArgs(
state=original_state.copy(),
axes=[1],
prng=np.random.RandomState(),
log_of_measurement_results={},
)
with pytest.raises(TypeError, match="Failed to act"):
cirq.act_on(UnsupportedComposite(), args)

np.testing.assert_allclose(args.state.state_vector(), original_state.state_vector())
Loading