Skip to content

Remove axes from ActOnArgs, pass qubits explicitly to act_on #4089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
2c5e7b2
Remove axes from ActOnArgs, pass qubits explicitly
daxfohl May 6, 2021
af048d3
Merge branch 'master' into remove_axes
daxfohl May 16, 2021
970beb3
Merge branch 'master' into remove_axes
daxfohl Jun 3, 2021
c6314d1
split protocols
daxfohl Jun 4, 2021
57564a7
require ActOnArgs to implement fallback
daxfohl Jun 4, 2021
2284a55
lint
daxfohl Jun 4, 2021
f052fb1
Split the protocols
daxfohl Jun 6, 2021
2736090
Fix tests and coverage
daxfohl Jun 6, 2021
87f2ce9
coverage
daxfohl Jun 6, 2021
97b9235
format
daxfohl Jun 6, 2021
786da8b
Merge branch 'master' into remove_axes
daxfohl Jun 6, 2021
d4ada5e
make param order consistent
daxfohl Jun 7, 2021
869ece6
format
daxfohl Jun 7, 2021
bb23492
add deprecation for axes
daxfohl Jun 13, 2021
aeb5a24
Merge branch 'master' into remove_axes
daxfohl Jun 13, 2021
d854ab5
v0.13
daxfohl Jun 13, 2021
19d86c4
lint
daxfohl Jun 13, 2021
c86bc0c
readd axes with mypy ignore
daxfohl Jun 13, 2021
837d7ee
safe
daxfohl Jun 13, 2021
2f08537
deprecate
daxfohl Jun 13, 2021
ada1822
fix args len
daxfohl Jun 13, 2021
f0a865c
tests
daxfohl Jun 13, 2021
ecba529
lint
daxfohl Jun 13, 2021
74c4a49
lint
daxfohl Jun 13, 2021
dd306ca
Merge branch 'master' into remove_axes
daxfohl Jun 14, 2021
888090d
cover
daxfohl Jun 14, 2021
b97b726
Merge remote-tracking branch 'origin/remove_axes' into remove_axes
daxfohl Jun 14, 2021
c025e17
Change _act_on_qubits_ dunder back to _act_on_
daxfohl Jun 17, 2021
8bb66e5
format
daxfohl Jun 17, 2021
95b5848
unify act_on
daxfohl Jun 17, 2021
d1c3451
lint
daxfohl Jun 17, 2021
4b933a6
Merge branch 'master' into remove_axes
daxfohl Jun 17, 2021
d4b64d1
exception
daxfohl Jun 17, 2021
992b82f
test
daxfohl Jun 17, 2021
b047323
format
daxfohl Jun 17, 2021
ff961e3
Merge branch 'master' into remove_axes
daxfohl Jun 17, 2021
84287c5
SupportsActOnQubits
daxfohl Jun 17, 2021
54daf36
Merge branch 'remove_axes' of https://github.com/daxfohl/Cirq into re…
daxfohl Jun 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@
resolve_parameters_once,
SerializableByKey,
SupportsActOn,
SupportsActOnQubits,
SupportsApplyChannel,
SupportsApplyMixture,
SupportsApproximateEquality,
Expand Down
12 changes: 9 additions & 3 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import quimb.tensor as qtn

from cirq import devices, study, ops, protocols, value
from cirq._compat import deprecated_parameter
from cirq.sim import simulator, simulator_base
from cirq.sim.act_on_args import ActOnArgs

Expand Down Expand Up @@ -224,6 +225,12 @@ def sample(
class MPSState(ActOnArgs):
"""A state of the MPS simulation."""

@deprecated_parameter(
deadline='v0.13',
fix='No longer needed. `protocols.act_on` infers axes.',
parameter_desc='axes',
match=lambda args, kwargs: 'axes' in kwargs or len(args) > 6,
)
def __init__(
self,
qubits: Sequence['cirq.Qid'],
Expand Down Expand Up @@ -451,7 +458,7 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState):
raise ValueError('Can only handle 1 and 2 qubit operations')
return True

def _act_on_fallback_(self, op: Any, allow_decompose: bool):
def _act_on_fallback_(self, op: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool):
"""Delegates the action to self.apply_op"""
return self.apply_op(op, self.prng)

Expand Down Expand Up @@ -524,7 +531,6 @@ def perform_measurement(

return results

def _perform_measurement(self) -> List[int]:
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Measures the axes specified by the simulator."""
qubits = [self.qubits[key] for key in self.axes]
return self.perform_measurement(qubits, self.prng)
3 changes: 1 addition & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,7 @@ def test_state_act_on_args_initializer():
s = ccq.mps_simulator.MPSState(
qubits=(cirq.LineQubit(0),),
prng=np.random.RandomState(0),
axes=[2],
log_of_measurement_results={'test': 4},
)
assert s.axes == (2,)
assert s.qubits == (cirq.LineQubit(0),)
assert s.log_of_measurement_results == {'test': 4}
17 changes: 9 additions & 8 deletions cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,13 @@ def __str__(self) -> str:
return f"depolarize(p={self._p})"
return f"depolarize(p={self._p},n_qubits={self._n_qubits})"

def _act_on_(self, args: Any) -> bool:
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']) -> bool:
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if args.prng.random() < self._p:
gate = args.prng.choice([pauli_gates.X, pauli_gates.Y, pauli_gates.Z])
protocols.act_on(gate, args)
protocols.act_on(gate, args, qubits)
return True
return NotImplemented

Expand Down Expand Up @@ -720,29 +720,30 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
def _qid_shape_(self):
return (self._dimension,)

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq import sim, ops

if isinstance(args, sim.ActOnStabilizerCHFormArgs):
(axe,) = args.axes
axe = args.qubit_map[qubits[0]]
if args.state._measure(axe, args.prng):
ops.X._act_on_(args)
ops.X._act_on_(args, qubits)
return True

if isinstance(args, sim.ActOnStateVectorArgs):
# Do a silent measurement.
axes = args.get_axes(qubits)
measurements, _ = sim.measure_state_vector(
args.target_tensor,
args.axes,
axes,
out=args.target_tensor,
qid_shape=args.target_tensor.shape,
)
result = measurements[0]

# Use measurement result to zero the qid.
if result:
zero = args.subspace_index(0)
other = args.subspace_index(result)
zero = args.subspace_index(axes, 0)
other = args.subspace_index(axes, result)
args.target_tensor[zero] = args.target_tensor[other]
args.target_tensor[other] = 0

Expand Down
11 changes: 6 additions & 5 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

import cirq
from cirq.protocols.act_on_protocol_test import DummyActOnArgs

X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Expand Down Expand Up @@ -478,26 +479,26 @@ def test_reset_channel_text_diagram():

def test_reset_act_on():
with pytest.raises(TypeError, match="Failed to act"):
cirq.act_on(cirq.ResetChannel(), object())
cirq.act_on(cirq.ResetChannel(), DummyActOnArgs(), qubits=())

args = cirq.ActOnStateVectorArgs(
target_tensor=cirq.one_hot(
index=(1, 1, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64
),
available_buffer=np.empty(shape=(2, 2, 2, 2, 2)),
axes=[1],
qubits=cirq.LineQubit.range(5),
prng=np.random.RandomState(),
log_of_measurement_results={},
)

cirq.act_on(cirq.ResetChannel(), args)
cirq.act_on(cirq.ResetChannel(), args, [cirq.LineQubit(1)])
assert args.log_of_measurement_results == {}
np.testing.assert_allclose(
args.target_tensor,
cirq.one_hot(index=(1, 0, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64),
)

cirq.act_on(cirq.ResetChannel(), args)
cirq.act_on(cirq.ResetChannel(), args, [cirq.LineQubit(1)])
assert args.log_of_measurement_results == {}
np.testing.assert_allclose(
args.target_tensor,
Expand Down Expand Up @@ -693,7 +694,7 @@ def test_bit_flip_channel_text_diagram():
def test_stabilizer_supports_depolarize():
with pytest.raises(TypeError, match="act_on"):
for _ in range(100):
cirq.act_on(cirq.depolarize(3 / 4), object())
cirq.act_on(cirq.depolarize(3 / 4), DummyActOnArgs(), qubits=())

q = cirq.LineQubit(0)
c = cirq.Circuit(cirq.depolarize(3 / 4).on(q), cirq.measure(q, key='m'))
Expand Down
52 changes: 26 additions & 26 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@
"""


def _act_with_gates(args, *gates: 'cirq.SupportsActOn') -> None:
def _act_with_gates(args, qubits, *gates: 'cirq.SupportsActOnQubits') -> None:
"""Act on the given args with the given gates in order."""
for gate in gates:
assert gate._act_on_(args)
assert gate._act_on_(args, qubits)


def _pi(rads):
Expand Down Expand Up @@ -108,14 +108,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.available_buffer *= p
return args.available_buffer

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
if effective_exponent == 0.5:
tableau.xs[:, q] ^= tableau.zs[:, q]
Expand All @@ -130,7 +130,7 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
_act_with_gates(args, H, ZPowGate(exponent=self._exponent), H)
_act_with_gates(args, qubits, H, ZPowGate(exponent=self._exponent), H)
# Adjust the global phase based on the global_shift parameter.
args.state.omega *= np.exp(1j * np.pi * self.global_shift * self.exponent)
return True
Expand Down Expand Up @@ -360,14 +360,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.available_buffer *= p
return args.available_buffer

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
if effective_exponent == 0.5:
tableau.rs[:] ^= tableau.xs[:, q] & (~tableau.zs[:, q])
Expand All @@ -392,13 +392,13 @@ def _act_on_(self, args: Any):
state = args.state
Z = ZPowGate()
if effective_exponent == 0.5:
_act_with_gates(args, Z, H)
_act_with_gates(args, qubits, Z, H)
state.omega *= (1 + 1j) / (2 ** 0.5)
elif effective_exponent == 1:
_act_with_gates(args, Z, H, Z, H)
_act_with_gates(args, qubits, Z, H, Z, H)
state.omega *= 1j
elif effective_exponent == 1.5:
_act_with_gates(args, H, Z)
_act_with_gates(args, qubits, H, Z)
state.omega *= (1 - 1j) / (2 ** 0.5)
# Adjust the global phase based on the global_shift parameter.
args.state.omega *= np.exp(1j * np.pi * self.global_shift * self.exponent)
Expand Down Expand Up @@ -579,14 +579,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.target_tensor *= p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
if effective_exponent == 0.5:
tableau.rs[:] ^= tableau.xs[:, q] & tableau.zs[:, q]
Expand All @@ -601,7 +601,7 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
state = args.state
for _ in range(int(effective_exponent * 2)):
Expand Down Expand Up @@ -896,14 +896,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.target_tensor *= np.sqrt(2) * p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
if self._exponent % 2 == 1:
(tableau.xs[:, q], tableau.zs[:, q]) = (
tableau.zs[:, q].copy(),
Expand All @@ -915,7 +915,7 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q = args.axes[0]
q = args.qubit_map[qubits[0]]
state = args.state
if self._exponent % 2 == 1:
# Prescription for H left multiplication
Expand Down Expand Up @@ -1059,15 +1059,15 @@ def _apply_unitary_(
args.target_tensor *= p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
if self._exponent % 2 == 1:
(tableau.xs[:, q2], tableau.zs[:, q2]) = (
tableau.zs[:, q2].copy(),
Expand All @@ -1088,8 +1088,8 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
state = args.state
if self._exponent % 2 == 1:
# Prescription for CZ left multiplication.
Expand Down Expand Up @@ -1282,15 +1282,15 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.target_tensor *= p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
if self._exponent % 2 == 1:
tableau.rs[:] ^= (
tableau.xs[:, q1]
Expand All @@ -1304,8 +1304,8 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
state = args.state
if self._exponent % 2 == 1:
# Prescription for CX left multiplication.
Expand Down
Loading