Skip to content

Commit aa303dc

Browse files
authored
Deprecate Gateset.accept_global_phase_op (#5239)
Deprecates Gateset.accept_global_phase_op **Breaking Change:** Changes the default value of `Gateset.accept_global_phase_op` from `True` to `False`. I can't think of any way to remove this parameter without eventually needing this breaking change. Currently all gatesets that are created allow global phase gates if they don't specify `accept_global_phase_op=False` explicitly. But the end goal is only to allow global phase gates if they're included in the `gates` list. So at some point in the transition the default behavior needs to break, and I can't think of a way of doing that via deprecation. Therefore I think we may as well do it now via this breaking change. Note that even though it's breaking, it isn't breaking in a bad way. Users who are adding global phase gates to things that suddenly don't accept them will just see an error that the gate is not in the gateset, and then go add it. It's much safer than breaking in the other direction in which we silently start allowing new gate types. Closes #4741 @tanujkhattar
1 parent 0f995ee commit aa303dc

25 files changed

+215
-77
lines changed

cirq-core/cirq/ion/ion_device.py

-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def get_ion_gateset() -> ops.Gateset:
3030
ops.ZPowGate,
3131
ops.PhasedXPowGate,
3232
unroll_circuit_op=False,
33-
accept_global_phase_op=False,
3433
)
3534

3635

cirq-core/cirq/ion/ion_device_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def num_qubits(self):
166166

167167

168168
def test_can_add_operation_into_moment_device_deprecated():
169-
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=5):
169+
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=6):
170170
d = ion_device(3)
171171
q0 = cirq.LineQubit(0)
172172
q1 = cirq.LineQubit(1)
@@ -218,10 +218,10 @@ def test_at():
218218

219219

220220
def test_qubit_set_deprecated():
221-
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15'):
221+
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15', count=2):
222222
assert ion_device(3).qubit_set() == frozenset(cirq.LineQubit.range(3))
223223

224224

225225
def test_qid_pairs_deprecated():
226-
with cirq.testing.assert_deprecated('device.metadata', deadline='v0.15', count=1):
226+
with cirq.testing.assert_deprecated('device.metadata', deadline='v0.15', count=2):
227227
assert len(ion_device(10).qid_pairs()) == 45

cirq-core/cirq/neutral_atoms/neutral_atom_devices.py

-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def neutral_atom_gateset(max_parallel_z=None, max_parallel_xy=None):
4444
ops.MeasurementGate,
4545
ops.IdentityGate,
4646
unroll_circuit_op=False,
47-
accept_global_phase_op=False,
4847
)
4948

5049

@@ -100,15 +99,13 @@ def __init__(
10099
ops.ParallelGateFamily(ops.YPowGate),
101100
ops.ParallelGateFamily(ops.PhasedXPowGate),
102101
unroll_circuit_op=False,
103-
accept_global_phase_op=False,
104102
)
105103
self.controlled_gateset = ops.Gateset(
106104
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
107105
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
108106
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
109107
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
110108
unroll_circuit_op=False,
111-
accept_global_phase_op=False,
112109
)
113110
self.gateset = neutral_atom_gateset(max_parallel_z, max_parallel_xy)
114111
for q in qubits:

cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_validate_moment_errors():
237237

238238

239239
def test_can_add_operation_into_moment_coverage_deprecated():
240-
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=3):
240+
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=4):
241241
d = square_device(2, 2)
242242
q00 = cirq.GridQubit(0, 0)
243243
q01 = cirq.GridQubit(0, 1)
@@ -298,5 +298,5 @@ def test_repr_pretty():
298298

299299

300300
def test_qubit_set_deprecated():
301-
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15'):
301+
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15', count=2):
302302
assert square_device(2, 2).qubit_set() == frozenset(cirq.GridQubit.square(2, 0, 0))

cirq-core/cirq/ops/gateset.py

+70-29
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414

1515
"""Functionality for grouping and validating Cirq Gates"""
1616

17+
import warnings
1718
from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
19+
20+
from cirq import _compat, protocols, value
1821
from cirq.ops import global_phase_op, op_tree, raw_types
19-
from cirq import protocols, value
2022

2123
if TYPE_CHECKING:
2224
import cirq
@@ -201,12 +203,20 @@ class Gateset:
201203
validation purposes.
202204
"""
203205

206+
@_compat.deprecated_parameter(
207+
deadline='v0.16',
208+
fix='To accept global phase gates, add cirq.GlobalPhaseGate to the list of *gates passed '
209+
'to the constructor. By default, global phase gates will not be accepted by the '
210+
'gateset',
211+
parameter_desc='accept_global_phase_op',
212+
match=lambda args, kwargs: 'accept_global_phase_op' in kwargs,
213+
)
204214
def __init__(
205215
self,
206216
*gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
207217
name: Optional[str] = None,
208218
unroll_circuit_op: bool = True,
209-
accept_global_phase_op: bool = True,
219+
accept_global_phase_op: Optional[bool] = None,
210220
) -> None:
211221
"""Init Gateset.
212222
@@ -225,17 +235,36 @@ def __init__(
225235
name: (Optional) Name for the Gateset. Useful for description.
226236
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
227237
validated by validating the underlying `cirq.Circuit`.
228-
accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
238+
accept_global_phase_op: If True, a `GateFamily` accepting
239+
`cirq.GlobalPhaseGate` will be included. If None,
240+
`cirq.GlobalPhaseGate` will not modify the input `*gates`.
241+
If False, `cirq.GlobalPhaseGate` will be removed from the
242+
gates. This parameter defaults to None (a breaking change from
243+
v0.14.1) and will be removed in v0.16.
229244
"""
230245
self._name = name
231246
self._unroll_circuit_op = unroll_circuit_op
232-
self._accept_global_phase_op = accept_global_phase_op
247+
if accept_global_phase_op:
248+
gates = gates + (global_phase_op.GlobalPhaseGate,)
233249
self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
234250
self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
235251
self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates])
236252
unique_gate_list: List[GateFamily] = list(
237253
dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates)
238254
)
255+
if accept_global_phase_op is False:
256+
unique_gate_list = [
257+
g for g in unique_gate_list if g.gate is not global_phase_op.GlobalPhaseGate
258+
]
259+
elif accept_global_phase_op is None:
260+
if not any(g.gate is global_phase_op.GlobalPhaseGate for g in unique_gate_list):
261+
warnings.warn(
262+
'v0.14.1 is the last release `cirq.GlobalPhaseGate` is included by default. If'
263+
' you were relying on this behavior, you can include a `cirq.GlobalPhaseGate`'
264+
' in your `*gates`. If not, then you can ignore this warning. It will be'
265+
' removed in v0.16'
266+
)
267+
239268
for g in unique_gate_list:
240269
if type(g) == GateFamily:
241270
if isinstance(g.gate, raw_types.Gate):
@@ -253,6 +282,12 @@ def name(self) -> Optional[str]:
253282
def gates(self) -> FrozenSet[GateFamily]:
254283
return self._gates
255284

285+
@_compat.deprecated_parameter(
286+
deadline='v0.16',
287+
fix='Add a global phase gate to the Gateset',
288+
parameter_desc='accept_global_phase_op',
289+
match=lambda args, kwargs: 'accept_global_phase_op' in kwargs,
290+
)
256291
def with_params(
257292
self,
258293
*,
@@ -268,7 +303,12 @@ def with_params(
268303
name: New name for the Gateset.
269304
unroll_circuit_op: If True, new Gateset will recursively validate
270305
`cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
271-
accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.
306+
accept_global_phase_op: If True, a `GateFamily` accepting
307+
`cirq.GlobalPhaseGate` will be included. If None,
308+
`cirq.GlobalPhaseGate` will not modify the input `*gates`.
309+
If False, `cirq.GlobalPhaseGate` will be removed from the
310+
gates. This parameter defaults to None (a breaking change from
311+
v0.14.1) and will be removed in v0.16.
272312
273313
Returns:
274314
`self` if all new values are None or identical to the values of current Gateset.
@@ -280,19 +320,23 @@ def val_if_none(var: Any, val: Any) -> Any:
280320

281321
name = val_if_none(name, self._name)
282322
unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
283-
accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
323+
global_phase_family = GateFamily(gate=global_phase_op.GlobalPhaseGate)
284324
if (
285325
name == self._name
286326
and unroll_circuit_op == self._unroll_circuit_op
287-
and accept_global_phase_op == self._accept_global_phase_op
327+
and (
328+
accept_global_phase_op is True
329+
and global_phase_family in self.gates
330+
or accept_global_phase_op is False
331+
and not any(g.gate is global_phase_op.GlobalPhaseGate for g in self.gates)
332+
or accept_global_phase_op is None
333+
)
288334
):
289335
return self
290-
return Gateset(
291-
*self.gates,
292-
name=name,
293-
unroll_circuit_op=cast(bool, unroll_circuit_op),
294-
accept_global_phase_op=cast(bool, accept_global_phase_op),
295-
)
336+
gates = self.gates
337+
if accept_global_phase_op:
338+
gates = gates.union({global_phase_family})
339+
return Gateset(*gates, name=name, unroll_circuit_op=cast(bool, unroll_circuit_op))
296340

297341
def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
298342
"""Check for containment of a given Gate/Operation in this Gateset.
@@ -326,9 +370,6 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
326370
g = item if isinstance(item, raw_types.Gate) else item.gate
327371
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'
328372

329-
if isinstance(g, global_phase_op.GlobalPhaseGate):
330-
return self._accept_global_phase_op
331-
332373
if g in self._instance_gate_families:
333374
assert item in self._instance_gate_families[g], (
334375
f"{item} instance matches {self._instance_gate_families[g]} but "
@@ -394,16 +435,15 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
394435
return False
395436

396437
def _value_equality_values_(self) -> Any:
397-
return (self.gates, self.name, self._unroll_circuit_op, self._accept_global_phase_op)
438+
return (self.gates, self.name, self._unroll_circuit_op)
398439

399440
def __repr__(self) -> str:
400441
name_str = f'name = "{self.name}", ' if self.name is not None else ''
401442
return (
402443
f'cirq.Gateset('
403444
f'{self._gates_repr_str}, '
404445
f'{name_str}'
405-
f'unroll_circuit_op = {self._unroll_circuit_op},'
406-
f'accept_global_phase_op = {self._accept_global_phase_op})'
446+
f'unroll_circuit_op = {self._unroll_circuit_op})'
407447
)
408448

409449
def __str__(self) -> str:
@@ -417,16 +457,17 @@ def _json_dict_(self) -> Dict[str, Any]:
417457
'gates': self._unique_gate_list,
418458
'name': self.name,
419459
'unroll_circuit_op': self._unroll_circuit_op,
420-
'accept_global_phase_op': self._accept_global_phase_op,
421460
}
422461

423462
@classmethod
424-
def _from_json_dict_(
425-
cls, gates, name, unroll_circuit_op, accept_global_phase_op, **kwargs
426-
) -> 'Gateset':
427-
return cls(
428-
*gates,
429-
name=name,
430-
unroll_circuit_op=unroll_circuit_op,
431-
accept_global_phase_op=accept_global_phase_op,
432-
)
463+
def _from_json_dict_(cls, gates, name, unroll_circuit_op, **kwargs) -> 'Gateset':
464+
if 'accept_global_phase_op' in kwargs:
465+
accept_global_phase_op = kwargs['accept_global_phase_op']
466+
global_phase_family = GateFamily(gate=global_phase_op.GlobalPhaseGate)
467+
if accept_global_phase_op is True:
468+
gates.append(global_phase_family)
469+
elif accept_global_phase_op is False:
470+
gates = [
471+
family for family in gates if family.gate is not global_phase_op.GlobalPhaseGate
472+
]
473+
return cls(*gates, name=name, unroll_circuit_op=unroll_circuit_op)

cirq-core/cirq/ops/gateset_test.py

+35-25
Original file line numberDiff line numberDiff line change
@@ -257,19 +257,21 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
257257
assert gateset.validate(item) is result
258258

259259
op_tree = [*get_ops(use_circuit_op, use_global_phase)]
260-
assert_validate_and_contains_consistent(
261-
gateset.with_params(
262-
unroll_circuit_op=use_circuit_op, accept_global_phase_op=use_global_phase
263-
),
264-
op_tree,
265-
True,
266-
)
267-
if use_circuit_op or use_global_phase:
260+
with cirq.testing.assert_deprecated('global phase', deadline='v0.16', count=None):
268261
assert_validate_and_contains_consistent(
269-
gateset.with_params(unroll_circuit_op=False, accept_global_phase_op=False),
262+
gateset.with_params(
263+
unroll_circuit_op=use_circuit_op, accept_global_phase_op=use_global_phase
264+
),
270265
op_tree,
271-
False,
266+
True,
272267
)
268+
if use_circuit_op or use_global_phase:
269+
with cirq.testing.assert_deprecated('global phase', deadline='v0.16', count=2):
270+
assert_validate_and_contains_consistent(
271+
gateset.with_params(unroll_circuit_op=False, accept_global_phase_op=False),
272+
op_tree,
273+
False,
274+
)
273275

274276

275277
def test_gateset_validate_circuit_op_negative_reps():
@@ -281,31 +283,39 @@ def test_gateset_validate_circuit_op_negative_reps():
281283

282284
def test_with_params():
283285
assert gateset.with_params() is gateset
284-
assert (
285-
gateset.with_params(
286-
name=gateset.name,
287-
unroll_circuit_op=gateset._unroll_circuit_op,
288-
accept_global_phase_op=gateset._accept_global_phase_op,
286+
with cirq.testing.assert_deprecated('global phase', deadline='v0.16'):
287+
assert (
288+
gateset.with_params(
289+
name=gateset.name,
290+
unroll_circuit_op=gateset._unroll_circuit_op,
291+
accept_global_phase_op=None,
292+
)
293+
is gateset
294+
)
295+
with cirq.testing.assert_deprecated('global phase', deadline='v0.16', count=2):
296+
gateset_with_params = gateset.with_params(
297+
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
289298
)
290-
is gateset
291-
)
292-
gateset_with_params = gateset.with_params(
293-
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
294-
)
295299
assert gateset_with_params.name == 'new name'
296300
assert gateset_with_params._unroll_circuit_op is False
297-
assert gateset_with_params._accept_global_phase_op is False
298301

299302

300303
def test_gateset_eq():
301304
eq = cirq.testing.EqualsTester()
302305
eq.add_equality_group(cirq.Gateset(CustomX))
303306
eq.add_equality_group(cirq.Gateset(CustomX**3))
304-
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset'))
307+
with cirq.testing.assert_deprecated('global phase', deadline='v0.16'):
308+
eq.add_equality_group(
309+
cirq.Gateset(CustomX, name='Custom Gateset'),
310+
cirq.Gateset(
311+
CustomX, cirq.GlobalPhaseGate, name='Custom Gateset', accept_global_phase_op=False
312+
),
313+
)
305314
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', unroll_circuit_op=False))
306-
eq.add_equality_group(
307-
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=False)
308-
)
315+
with cirq.testing.assert_deprecated('global phase', deadline='v0.16'):
316+
eq.add_equality_group(
317+
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=True)
318+
)
309319
eq.add_equality_group(
310320
cirq.Gateset(
311321
cirq.GateFamily(CustomX, name='custom_name', description='custom_description'),

cirq-core/cirq/optimizers/convert_to_cz_and_single_gates.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, ignore_failures: bool = False, allow_partial_czs: bool = Fals
5151
ops.CZPowGate if allow_partial_czs else ops.CZ,
5252
ops.MeasurementGate,
5353
ops.AnyUnitaryGateFamily(1),
54+
ops.GlobalPhaseGate,
5455
)
5556

5657
def _decompose_two_qubit_unitaries(self, op: ops.Operation) -> ops.OP_TREE:

cirq-core/cirq/optimizers/merge_interactions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def __init__(
235235
self.allow_partial_czs = allow_partial_czs
236236
self.gateset = ops.Gateset(
237237
ops.CZPowGate if allow_partial_czs else ops.CZ,
238+
ops.GlobalPhaseGate,
238239
unroll_circuit_op=False,
239-
accept_global_phase_op=True,
240240
)
241241

242242
def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:

cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def __init__(
7676
self.use_sqrt_iswap_inv = use_sqrt_iswap_inv
7777
self.gateset = ops.Gateset(
7878
ops.SQRT_ISWAP_INV if use_sqrt_iswap_inv else ops.SQRT_ISWAP,
79+
ops.GlobalPhaseGate,
7980
unroll_circuit_op=False,
80-
accept_global_phase_op=True,
8181
)
8282

8383
def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:

cirq-core/cirq/protocols/json_test_data/Gateset.json

+2-4
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
}
2727
],
2828
"name": null,
29-
"unroll_circuit_op": true,
30-
"accept_global_phase_op": true
29+
"unroll_circuit_op": true
3130
},
3231
{
3332
"cirq_type": "Gateset",
@@ -56,7 +55,6 @@
5655
}
5756
],
5857
"name": "Custom Name",
59-
"unroll_circuit_op": false,
60-
"accept_global_phase_op": false
58+
"unroll_circuit_op": false
6159
}
6260
]

0 commit comments

Comments
 (0)