Skip to content

Commit 27dd607

Browse files
authored
Add __cirq_debug__ flag and conditionally disable qid validations in gates and operations (#6000)
* Add __cirq_debug__ flag and conditionally disable qid validations in gates and operations * fix mypy errors * Fix typo * Address comments and add a context manager * Address nit
1 parent fd491e0 commit 27dd607

File tree

6 files changed

+103
-14
lines changed

6 files changed

+103
-14
lines changed

Diff for: asv.conf.json

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"environment_type": "virtualenv",
1010
"show_commit_url": "https://github.com/quantumlib/Cirq/commit/",
1111
"pythons": ["3.8"],
12+
"matrix": {"env_nobuild": {"PYTHONOPTIMIZE": ["-O", ""]}},
1213
"benchmark_dir": "benchmarks",
1314
"env_dir": ".asv/env",
1415
"results_dir": ".asv/results",

Diff for: cirq-core/cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from cirq import _import
1818

19+
from cirq._compat import __cirq_debug__, with_debug
20+
1921
# A module can only depend on modules imported earlier in this list of modules
2022
# at import time. Pytest will fail otherwise (enforced by
2123
# dev_tools/import_test.py).

Diff for: cirq-core/cirq/_compat.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Workarounds for compatibility issues between versions and libraries."""
1616
import contextlib
17+
import contextvars
1718
import dataclasses
1819
import functools
1920
import importlib
@@ -24,15 +25,41 @@
2425
import traceback
2526
import warnings
2627
from types import ModuleType
27-
from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar
28+
from typing import Any, Callable, Dict, Iterator, Optional, overload, Set, Tuple, Type, TypeVar
2829

2930
import numpy as np
3031
import pandas as pd
3132
import sympy
3233
import sympy.printing.repr
3334

35+
from cirq._doc import document
36+
3437
ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'
3538

39+
__cirq_debug__ = contextvars.ContextVar('__cirq_debug__', default=__debug__)
40+
document(
41+
__cirq_debug__,
42+
"A cirq specific flag which can be used to conditionally turn off all validations across Cirq "
43+
"to boost performance in production mode. Defaults to python's built-in constant __debug__. "
44+
"The flag is implemented as a `ContextVar` and is thread safe.",
45+
)
46+
47+
48+
@contextlib.contextmanager
49+
def with_debug(value: bool) -> Iterator[None]:
50+
"""Sets the value of global constant `cirq.__cirq_debug__` within the context.
51+
52+
If `__cirq_debug__` is set to False, all validations in Cirq are disabled to optimize
53+
performance. Users should use the `cirq.with_debug` context manager instead of manually
54+
mutating the value of `__cirq_debug__` flag. On exit, the context manager resets the
55+
value of `__cirq_debug__` flag to what it was before entering the context manager.
56+
"""
57+
token = __cirq_debug__.set(value)
58+
try:
59+
yield
60+
finally:
61+
__cirq_debug__.reset(token)
62+
3663

3764
try:
3865
from functools import cached_property # pylint: disable=unused-import

Diff for: cirq-core/cirq/_compat_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@
5151
)
5252

5353

54+
def test_with_debug():
55+
assert cirq.__cirq_debug__.get()
56+
with cirq.with_debug(False):
57+
assert not cirq.__cirq_debug__.get()
58+
with cirq.with_debug(True):
59+
assert cirq.__cirq_debug__.get()
60+
assert not cirq.__cirq_debug__.get()
61+
assert cirq.__cirq_debug__.get()
62+
63+
5464
def test_proper_repr():
5565
v = sympy.Symbol('t') * 3
5666
v2 = eval(proper_repr(v))

Diff for: cirq-core/cirq/ops/raw_types.py

+31-13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import abc
1818
import functools
1919
from typing import (
20+
cast,
2021
AbstractSet,
2122
Any,
2223
Callable,
@@ -40,6 +41,7 @@
4041

4142
from cirq import protocols, value
4243
from cirq._import import LazyLoader
44+
from cirq._compat import __cirq_debug__
4345
from cirq.type_workarounds import NotImplementedType
4446
from cirq.ops import control_values as cv
4547

@@ -215,7 +217,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None:
215217
Raises:
216218
ValueError: The gate can't be applied to the qubits.
217219
"""
218-
_validate_qid_shape(self, qubits)
220+
if __cirq_debug__.get():
221+
_validate_qid_shape(self, qubits)
219222

220223
def on(self, *qubits: Qid) -> 'Operation':
221224
"""Returns an application of this gate to the given qubits.
@@ -254,19 +257,33 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation']
254257
raise TypeError(f'{targets[0]} object is not iterable.')
255258
t0 = list(targets[0])
256259
iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0
257-
for target in iterator:
258-
if not isinstance(target, Sequence):
259-
raise ValueError(
260-
f'Inputs to multi-qubit gates must be Sequence[Qid].'
261-
f' Type: {type(target)}'
262-
)
263-
if not all(isinstance(x, Qid) for x in target):
264-
raise ValueError(f'All values in sequence should be Qids, but got {target}')
265-
if len(target) != self._num_qubits_():
266-
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
267-
operations.append(self.on(*target))
260+
if __cirq_debug__.get():
261+
for target in iterator:
262+
if not isinstance(target, Sequence):
263+
raise ValueError(
264+
f'Inputs to multi-qubit gates must be Sequence[Qid].'
265+
f' Type: {type(target)}'
266+
)
267+
if not all(isinstance(x, Qid) for x in target):
268+
raise ValueError(f'All values in sequence should be Qids, but got {target}')
269+
if len(target) != self._num_qubits_():
270+
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
271+
operations.append(self.on(*target))
272+
else:
273+
operations = [self.on(*target) for target in iterator]
268274
return operations
269275

276+
if not __cirq_debug__.get():
277+
return [
278+
op
279+
for q in targets
280+
for op in (
281+
self.on_each(*q)
282+
if isinstance(q, Iterable) and not isinstance(q, str)
283+
else [self.on(cast('cirq.Qid', q))]
284+
)
285+
]
286+
270287
for target in targets:
271288
if isinstance(target, Qid):
272289
operations.append(self.on(target))
@@ -617,7 +634,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
617634
Raises:
618635
ValueError: The operation had qids that don't match it's qid shape.
619636
"""
620-
_validate_qid_shape(self, qubits)
637+
if __cirq_debug__.get():
638+
_validate_qid_shape(self, qubits)
621639

622640
def _commutes_(
623641
self, other: Any, *, atol: float = 1e-8

Diff for: cirq-core/cirq/ops/raw_types_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,29 @@ def test_op_validate():
151151
op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)])
152152

153153

154+
def test_disable_op_validation():
155+
q0, q1 = cirq.LineQubit.range(2)
156+
h_op = cirq.H(q0)
157+
158+
# Fails normally.
159+
with pytest.raises(ValueError, match='Wrong number'):
160+
_ = cirq.H(q0, q1)
161+
with pytest.raises(ValueError, match='Wrong number'):
162+
h_op.validate_args([q0, q1])
163+
164+
# Passes, skipping validation.
165+
with cirq.with_debug(False):
166+
op = cirq.H(q0, q1)
167+
assert op.qubits == (q0, q1)
168+
h_op.validate_args([q0, q1])
169+
170+
# Fails again when validation is re-enabled.
171+
with pytest.raises(ValueError, match='Wrong number'):
172+
_ = cirq.H(q0, q1)
173+
with pytest.raises(ValueError, match='Wrong number'):
174+
h_op.validate_args([q0, q1])
175+
176+
154177
def test_default_validation_and_inverse():
155178
class TestGate(cirq.Gate):
156179
def _num_qubits_(self):
@@ -787,6 +810,10 @@ def matrix(self):
787810
test_non_qubits = [str(i) for i in range(3)]
788811
with pytest.raises(ValueError):
789812
_ = g.on_each(*test_non_qubits)
813+
814+
with cirq.with_debug(False):
815+
assert g.on_each(*test_non_qubits)[0].qubits == ('0',)
816+
790817
with pytest.raises(ValueError):
791818
_ = g.on_each(*test_non_qubits)
792819

@@ -853,6 +880,10 @@ def test_on_each_two_qubits():
853880
g.on_each([(a,)])
854881
with pytest.raises(ValueError, match='Expected 2 qubits'):
855882
g.on_each([(a, b, a)])
883+
884+
with cirq.with_debug(False):
885+
assert g.on_each([(a, b, a)])[0].qubits == (a, b, a)
886+
856887
with pytest.raises(ValueError, match='Expected 2 qubits'):
857888
g.on_each(zip([a, a]))
858889
with pytest.raises(ValueError, match='Expected 2 qubits'):

0 commit comments

Comments
 (0)