Skip to content

Commit e4c191d

Browse files
authored
Support ActOnStabilizerCHFormArgs in the common gates and testing (#3203)
* Add ActOnStabilizerCHFormArgs and related logic * Some comment updates * Fix _update_sum call * Update parameterized condition * Undo already split out change to stabilizerstatechfrom * revert split out change * Merge and clean * Improve error handling and documentation * Add comments with reference to the relevant sections of the paper * Fix the year in copyright * Use H instead of HPowGate() * Address comments * Hide traceback and correct copyright year
1 parent e1754ec commit e4c191d

10 files changed

+477
-78
lines changed

cirq/ops/common_gates.py

+115-3
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ def _act_on_(self, args: Any):
103103
tableau.xs[:, q] ^= tableau.zs[:, q]
104104
return True
105105

106+
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
107+
if protocols.is_parameterized(self) or self.exponent % 0.5 != 0:
108+
return NotImplemented
109+
assert all(
110+
gate._act_on_(args) for gate in # type: ignore
111+
[H, ZPowGate(exponent=self._exponent), H])
112+
# Adjust the global phase based on the global_shift parameter.
113+
args.state.omega *= np.exp(1j * np.pi * self.global_shift *
114+
self.exponent)
115+
return True
116+
106117
return NotImplemented
107118

108119
def in_su2(self) -> 'XPowGate':
@@ -322,6 +333,30 @@ def _act_on_(self, args: Any):
322333
tableau.xs[:, q].copy())
323334
return True
324335

336+
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
337+
if protocols.is_parameterized(self) or self.exponent % 0.5 != 0:
338+
return NotImplemented
339+
effective_exponent = self._exponent % 2
340+
state = args.state
341+
if effective_exponent == 0.5:
342+
assert all(
343+
gate._act_on_(args) # type: ignore
344+
for gate in [ZPowGate(), H])
345+
state.omega *= (1 + 1j) / (2**0.5) # type: ignore
346+
elif effective_exponent == 1:
347+
assert all(
348+
gate._act_on_(args) for gate in # type: ignore
349+
[ZPowGate(), H, ZPowGate(), H])
350+
state.omega *= 1j # type: ignore
351+
elif effective_exponent == 1.5:
352+
assert all(
353+
gate._act_on_(args) # type: ignore
354+
for gate in [H, ZPowGate()])
355+
state.omega *= (1 - 1j) / (2**0.5) # type: ignore
356+
# Adjust the global phase based on the global_shift parameter.
357+
args.state.omega *= np.exp(1j * np.pi * self.global_shift *
358+
self.exponent)
359+
return True
325360
return NotImplemented
326361

327362
def in_su2(self) -> 'YPowGate':
@@ -490,6 +525,22 @@ def _act_on_(self, args: Any):
490525
tableau.zs[:, q] ^= tableau.xs[:, q]
491526
return True
492527

528+
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
529+
if protocols.is_parameterized(self) or self.exponent % 0.5 != 0:
530+
return NotImplemented
531+
q = args.axes[0]
532+
effective_exponent = self._exponent % 2
533+
state = args.state
534+
for _ in range(int(effective_exponent * 2)):
535+
# Prescription for S left multiplication.
536+
# Reference: https://arxiv.org/abs/1808.00128 Proposition 4 end
537+
state.M[q, :] ^= state.G[q, :]
538+
state.gamma[q] = (state.gamma[q] - 1) % 4
539+
# Adjust the global phase based on the global_shift parameter.
540+
args.state.omega *= np.exp(1j * np.pi * self.global_shift *
541+
self.exponent)
542+
return True
543+
493544
return NotImplemented
494545

495546
def _decompose_into_clifford_with_qubits_(self, qubits):
@@ -756,18 +807,43 @@ def _act_on_(self, args: Any):
756807
from cirq.sim import clifford
757808

758809
if isinstance(args, clifford.ActOnCliffordTableauArgs):
759-
if protocols.is_parameterized(self) or self.exponent % 0.5 != 0:
810+
if protocols.is_parameterized(self) or self.exponent % 1 != 0:
760811
return NotImplemented
761812
tableau = args.tableau
762813
q = args.axes[0]
763-
if self._exponent % 1 != 0:
764-
return NotImplemented
765814
if self._exponent % 2 == 1:
766815
(tableau.xs[:, q], tableau.zs[:, q]) = (tableau.zs[:, q].copy(),
767816
tableau.xs[:, q].copy())
768817
tableau.rs[:] ^= (tableau.xs[:, q] & tableau.zs[:, q])
769818
return True
770819

820+
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
821+
if protocols.is_parameterized(self) or self.exponent % 1 != 0:
822+
return NotImplemented
823+
q = args.axes[0]
824+
state = args.state
825+
if self._exponent % 2 == 1:
826+
# Prescription for H left multiplication
827+
# Reference: https://arxiv.org/abs/1808.00128
828+
# Equations 48, 49 and Proposition 4
829+
t = state.s ^ (state.G[q, :] & state.v)
830+
u = state.s ^ (state.F[q, :] &
831+
(~state.v)) ^ (state.M[q, :] & state.v)
832+
833+
alpha = sum(state.G[q, :] & (~state.v) & state.s) % 2
834+
beta = sum(state.M[q, :] & (~state.v) & state.s)
835+
beta += sum(state.F[q, :] & state.v & state.M[q, :])
836+
beta += sum(state.F[q, :] & state.v & state.s)
837+
beta %= 2
838+
839+
delta = (state.gamma[q] + 2 * (alpha + beta)) % 4
840+
841+
state.update_sum(t, u, delta=delta, alpha=alpha)
842+
# Adjust the global phase based on the global_shift parameter.
843+
args.state.omega *= np.exp(1j * np.pi * self.global_shift *
844+
self.exponent)
845+
return True
846+
771847
return NotImplemented
772848

773849
def _decompose_(self, qubits):
@@ -900,6 +976,22 @@ def _act_on_(self, args: Any):
900976
tableau.rs[:] ^= (tableau.xs[:, q2] & tableau.zs[:, q2])
901977
return True
902978

979+
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
980+
if protocols.is_parameterized(self) or self.exponent % 1 != 0:
981+
return NotImplemented
982+
q1 = args.axes[0]
983+
q2 = args.axes[1]
984+
state = args.state
985+
if self._exponent % 2 == 1:
986+
# Prescription for CZ left multiplication.
987+
# Reference: https://arxiv.org/abs/1808.00128 Proposition 4 end
988+
state.M[q1, :] ^= state.G[q2, :]
989+
state.M[q2, :] ^= state.G[q1, :]
990+
# Adjust the global phase based on the global_shift parameter.
991+
args.state.omega *= np.exp(1j * np.pi * self.global_shift *
992+
self.exponent)
993+
return True
994+
903995
return NotImplemented
904996

905997
def _pauli_expansion_(self) -> value.LinearDict[str]:
@@ -1098,6 +1190,26 @@ def _act_on_(self, args: Any):
10981190
tableau.zs[:, q1] ^= tableau.zs[:, q2]
10991191
return True
11001192

1193+
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
1194+
if protocols.is_parameterized(self) or self.exponent % 1 != 0:
1195+
return NotImplemented
1196+
q1 = args.axes[0]
1197+
q2 = args.axes[1]
1198+
state = args.state
1199+
if self._exponent % 2 == 1:
1200+
# Prescription for CX left multiplication.
1201+
# Reference: https://arxiv.org/abs/1808.00128 Proposition 4 end
1202+
state.gamma[q1] = (
1203+
state.gamma[q1] + state.gamma[q2] + 2 *
1204+
(sum(state.M[q1, :] & state.F[q2, :]) % 2)) % 4
1205+
state.G[q2, :] ^= state.G[q1, :]
1206+
state.F[q1, :] ^= state.F[q2, :]
1207+
state.M[q1, :] ^= state.M[q2, :]
1208+
# Adjust the global phase based on the global_shift parameter.
1209+
args.state.omega *= np.exp(1j * np.pi * self.global_shift *
1210+
self.exponent)
1211+
return True
1212+
11011213
return NotImplemented
11021214

11031215
def _pauli_expansion_(self) -> value.LinearDict[str]:

cirq/ops/common_gates_test.py

+113-42
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_h_str():
290290
assert str(cirq.H**0.5) == 'H^0.5'
291291

292292

293-
def test_x_act_on():
293+
def test_x_act_on_tableau():
294294
with pytest.raises(TypeError, match="Failed to act"):
295295
cirq.act_on(cirq.X, object())
296296
original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31)
@@ -326,14 +326,21 @@ def test_x_act_on():
326326
cirq.act_on(cirq.X**foo, args)
327327

328328

329-
class PhaserGate(cirq.SingleQubitGate):
329+
class iZGate(cirq.SingleQubitGate):
330330
"""Equivalent to an iZ gate without _act_on_ defined on it."""
331331

332332
def _unitary_(self):
333333
return np.array([[1j, 0], [0, -1j]])
334334

335335

336-
def test_y_act_on():
336+
class MinusOnePhaseGate(cirq.SingleQubitGate):
337+
"""Equivalent to a -1 global phase without _act_on_ defined on it."""
338+
339+
def _unitary_(self):
340+
return np.array([[-1, 0], [0, -1]])
341+
342+
343+
def test_y_act_on_tableau():
337344
with pytest.raises(TypeError, match="Failed to act"):
338345
cirq.act_on(cirq.Y, object())
339346
original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31)
@@ -348,18 +355,18 @@ def test_y_act_on():
348355

349356
cirq.act_on(cirq.Y**0.5, args, allow_decompose=False)
350357
cirq.act_on(cirq.Y**0.5, args, allow_decompose=False)
351-
cirq.act_on(PhaserGate(), args)
358+
cirq.act_on(iZGate(), args)
352359
assert args.log_of_measurement_results == {}
353360
assert args.tableau == flipped_tableau
354361

355362
cirq.act_on(cirq.Y, args, allow_decompose=False)
356-
cirq.act_on(PhaserGate(), args, allow_decompose=True)
363+
cirq.act_on(iZGate(), args, allow_decompose=True)
357364
assert args.log_of_measurement_results == {}
358365
assert args.tableau == original_tableau
359366

360367
cirq.act_on(cirq.Y**3.5, args, allow_decompose=False)
361368
cirq.act_on(cirq.Y**3.5, args, allow_decompose=False)
362-
cirq.act_on(PhaserGate(), args)
369+
cirq.act_on(iZGate(), args)
363370
assert args.log_of_measurement_results == {}
364371
assert args.tableau == flipped_tableau
365372

@@ -372,9 +379,11 @@ def test_y_act_on():
372379
cirq.act_on(cirq.Y**foo, args)
373380

374381

375-
def test_z_h_act_on():
382+
def test_z_h_act_on_tableau():
376383
with pytest.raises(TypeError, match="Failed to act"):
377-
cirq.act_on(cirq.Y, object())
384+
cirq.act_on(cirq.Z, object())
385+
with pytest.raises(TypeError, match="Failed to act"):
386+
cirq.act_on(cirq.H, object())
378387
original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31)
379388
flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23)
380389

@@ -417,18 +426,16 @@ def test_z_h_act_on():
417426
with pytest.raises(TypeError, match="Failed to act action on state"):
418427
cirq.act_on(cirq.Z**foo, args)
419428

420-
foo = sympy.Symbol('foo')
421429
with pytest.raises(TypeError, match="Failed to act action on state"):
422430
cirq.act_on(cirq.H**foo, args)
423431

424-
foo = sympy.Symbol('foo')
425432
with pytest.raises(TypeError, match="Failed to act action on state"):
426433
cirq.act_on(cirq.H**1.5, args)
427434

428435

429-
def test_cx_act_on():
436+
def test_cx_act_on_tableau():
430437
with pytest.raises(TypeError, match="Failed to act"):
431-
cirq.act_on(cirq.Y, object())
438+
cirq.act_on(cirq.CX, object())
432439
original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31)
433440

434441
args = cirq.ActOnCliffordTableauArgs(
@@ -471,7 +478,7 @@ def test_cx_act_on():
471478
cirq.act_on(cirq.CX**1.5, args)
472479

473480

474-
def test_cz_act_on():
481+
def test_cz_act_on_tableau():
475482
with pytest.raises(TypeError, match="Failed to act"):
476483
cirq.act_on(cirq.Y, object())
477484
original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31)
@@ -516,38 +523,102 @@ def test_cz_act_on():
516523
cirq.act_on(cirq.CZ**1.5, args)
517524

518525

526+
foo = sympy.Symbol('foo')
527+
528+
529+
@pytest.mark.parametrize('input_gate_sequence, outcome', [
530+
([cirq.X**foo], 'Error'),
531+
([cirq.X**0.25], 'Error'),
532+
([cirq.X**4], 'Original'),
533+
([cirq.X**0.5, cirq.X**0.5], 'Flipped'),
534+
([cirq.X], 'Flipped'),
535+
([cirq.X**3.5, cirq.X**3.5], 'Flipped'),
536+
([cirq.Y**foo], 'Error'),
537+
([cirq.Y**0.25], 'Error'),
538+
([cirq.Y**4], 'Original'),
539+
([cirq.Y**0.5, cirq.Y**0.5, iZGate()], 'Flipped'),
540+
([cirq.Y, iZGate()], 'Flipped'),
541+
([cirq.Y**3.5, cirq.Y**3.5, iZGate()], 'Flipped'),
542+
([cirq.Z**foo], 'Error'),
543+
([cirq.H**foo], 'Error'),
544+
([cirq.H**1.5], 'Error'),
545+
([cirq.Z**4], 'Original'),
546+
([cirq.H**4], 'Original'),
547+
([cirq.H, cirq.S, cirq.S, cirq.H], 'Flipped'),
548+
([cirq.H, cirq.Z, cirq.H], 'Flipped'),
549+
([cirq.H, cirq.Z**3.5, cirq.Z**3.5, cirq.H], 'Flipped'),
550+
([cirq.CX**foo], 'Error'),
551+
([cirq.CX**1.5], 'Error'),
552+
([cirq.CX**4], 'Original'),
553+
([cirq.CX], 'Flipped'),
554+
([cirq.CZ**foo], 'Error'),
555+
([cirq.CZ**1.5], 'Error'),
556+
([cirq.CZ**4], 'Original'),
557+
([cirq.CZ, MinusOnePhaseGate()], 'Original'),
558+
])
559+
def test_act_on_ch_form(input_gate_sequence, outcome):
560+
original_state = cirq.StabilizerStateChForm(num_qubits=5, initial_state=31)
561+
num_qubits = cirq.num_qubits(input_gate_sequence[0])
562+
if num_qubits == 1:
563+
axes = [1]
564+
else:
565+
assert num_qubits == 2
566+
axes = [0, 1]
567+
args = cirq.ActOnStabilizerCHFormArgs(state=original_state.copy(),
568+
axes=axes)
569+
570+
flipped_state = cirq.StabilizerStateChForm(num_qubits=5, initial_state=23)
571+
572+
if outcome == 'Error':
573+
with pytest.raises(TypeError, match="Failed to act action on state"):
574+
for input_gate in input_gate_sequence:
575+
cirq.act_on(input_gate, args)
576+
return
577+
578+
for input_gate in input_gate_sequence:
579+
cirq.act_on(input_gate, args)
580+
581+
if outcome == 'Original':
582+
np.testing.assert_allclose(args.state.state_vector(),
583+
original_state.state_vector())
584+
585+
if outcome == 'Flipped':
586+
np.testing.assert_allclose(args.state.state_vector(),
587+
flipped_state.state_vector())
588+
589+
519590
@pytest.mark.parametrize(
520-
'input_gate',
591+
'input_gate, assert_implemented',
521592
[
522-
cirq.X,
523-
cirq.Y,
524-
cirq.Z,
525-
cirq.X**0.5,
526-
cirq.Y**0.5,
527-
cirq.Z**0.5,
528-
cirq.X**3.5,
529-
cirq.Y**3.5,
530-
cirq.Z**3.5,
531-
cirq.X**4,
532-
cirq.Y**4,
533-
cirq.Z**4,
534-
cirq.H,
535-
cirq.CX,
536-
cirq.CZ,
537-
cirq.H**4,
538-
cirq.CX**4,
539-
cirq.CZ**4,
540-
# Gates not supported by CliffordTableau should not fail too.
541-
cirq.X**0.25,
542-
cirq.Y**0.25,
543-
cirq.Z**0.25,
544-
cirq.H**0.5,
545-
cirq.CX**0.5,
546-
cirq.CZ**0.5
593+
(cirq.X, True),
594+
(cirq.Y, True),
595+
(cirq.Z, True),
596+
(cirq.X**0.5, True),
597+
(cirq.Y**0.5, True),
598+
(cirq.Z**0.5, True),
599+
(cirq.X**3.5, True),
600+
(cirq.Y**3.5, True),
601+
(cirq.Z**3.5, True),
602+
(cirq.X**4, True),
603+
(cirq.Y**4, True),
604+
(cirq.Z**4, True),
605+
(cirq.H, True),
606+
(cirq.CX, True),
607+
(cirq.CZ, True),
608+
(cirq.H**4, True),
609+
(cirq.CX**4, True),
610+
(cirq.CZ**4, True),
611+
# Unsupported gates should not fail too.
612+
(cirq.X**0.25, False),
613+
(cirq.Y**0.25, False),
614+
(cirq.Z**0.25, False),
615+
(cirq.H**0.5, False),
616+
(cirq.CX**0.5, False),
617+
(cirq.CZ**0.5, False),
547618
])
548-
def test_act_on_clifford_tableau(input_gate):
549-
cirq.testing.assert_act_on_clifford_tableau_effect_matches_unitary(
550-
input_gate)
619+
def test_act_on_consistency(input_gate, assert_implemented):
620+
cirq.testing.assert_all_implemented_act_on_effects_match_unitary(
621+
input_gate, assert_implemented, assert_implemented)
551622

552623

553624
def test_runtime_types_of_rot_gates():

0 commit comments

Comments
 (0)