From dc0c8262d5cebaefa8511b79ea5e64a4de713d24 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 6 Feb 2023 20:34:48 -0800 Subject: [PATCH 1/5] Add __cirq_debug__ flag and conditionally disable qid validations in gates and operations --- asv.conf.json | 1 + cirq-core/cirq/__init__.py | 2 ++ cirq-core/cirq/_compat.py | 11 ++++++++ cirq-core/cirq/ops/raw_types.py | 40 +++++++++++++++++++--------- cirq-core/cirq/ops/raw_types_test.py | 34 +++++++++++++++++++++++ cirq-core/requirements.txt | 1 + 6 files changed, 76 insertions(+), 13 deletions(-) diff --git a/asv.conf.json b/asv.conf.json index 35813c17405..e5c94df2f88 100644 --- a/asv.conf.json +++ b/asv.conf.json @@ -9,6 +9,7 @@ "environment_type": "virtualenv", "show_commit_url": "https://github.com/quantumlib/Cirq/commit/", "pythons": ["3.8"], + "matrix": {"env_nobuild": {"PYTHONOPTIMIZE": ["-O", ""]}}, "benchmark_dir": "benchmarks", "env_dir": ".asv/env", "results_dir": ".asv/results", diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index d68fc645b25..55ff0edf484 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -16,6 +16,8 @@ from cirq import _import +from cirq._compat import __cirq_debug__ + # A module can only depend on modules imported earlier in this list of modules # at import time. Pytest will fail otherwise (enforced by # dev_tools/import_test.py). diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index cf123d0408b..cb79ef5e905 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -15,6 +15,8 @@ """Workarounds for compatibility issues between versions and libraries.""" import contextlib import dataclasses + +import contextvars import functools import importlib import inspect @@ -31,8 +33,17 @@ import sympy import sympy.printing.repr +from cirq._doc import document + ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST' +__cirq_debug__ = contextvars.ContextVar('__cirq_debug__', default=__debug__) +document( + __cirq_debug__, + "A cirq specific flag which can be used to conditionally turn off all validations across Cirq " + "to boost performance in production mode. Defaults to python's built-in constant __debug__. " + "The flag is implemented as a `ContextVar` and is thread safe.", +) try: from functools import cached_property # pylint: disable=unused-import diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 3dcd6b2004a..63601487cb8 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -40,6 +40,7 @@ from cirq import protocols, value from cirq._import import LazyLoader +from cirq._compat import __cirq_debug__ from cirq.type_workarounds import NotImplementedType from cirq.ops import control_values as cv @@ -215,7 +216,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None: Raises: ValueError: The gate can't be applied to the qubits. """ - _validate_qid_shape(self, qubits) + if __cirq_debug__.get(): + _validate_qid_shape(self, qubits) def on(self, *qubits: Qid) -> 'Operation': """Returns an application of this gate to the given qubits. @@ -254,19 +256,30 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation'] raise TypeError(f'{targets[0]} object is not iterable.') t0 = list(targets[0]) iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0 - for target in iterator: - if not isinstance(target, Sequence): - raise ValueError( - f'Inputs to multi-qubit gates must be Sequence[Qid].' - f' Type: {type(target)}' - ) - if not all(isinstance(x, Qid) for x in target): - raise ValueError(f'All values in sequence should be Qids, but got {target}') - if len(target) != self._num_qubits_(): - raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}') - operations.append(self.on(*target)) + if __cirq_debug__.get(): + for target in iterator: + if not isinstance(target, Sequence): + raise ValueError( + f'Inputs to multi-qubit gates must be Sequence[Qid].' + f' Type: {type(target)}' + ) + if not all(isinstance(x, Qid) for x in target): + raise ValueError(f'All values in sequence should be Qids, but got {target}') + if len(target) != self._num_qubits_(): + raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}') + operations.append(self.on(*target)) + else: + operations = [self.on(*target) for target in iterator] return operations + if __cirq_debug__.get() is False: + return [ + self.on_each(*q) + if isinstance(q, Iterable) and not isinstance(q, str) + else self.on(q) + for q in targets + ] + for target in targets: if isinstance(target, Qid): operations.append(self.on(target)) @@ -617,7 +630,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']): Raises: ValueError: The operation had qids that don't match it's qid shape. """ - _validate_qid_shape(self, qubits) + if __cirq_debug__.get(): + _validate_qid_shape(self, qubits) def _commutes_( self, other: Any, *, atol: float = 1e-8 diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 2235f43742f..b8b77f5d257 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -151,6 +151,30 @@ def test_op_validate(): op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)]) +def test_disable_op_validation(): + q0, q1 = cirq.LineQubit.range(2) + h_op = cirq.H(q0) + + # Fails normally. + with pytest.raises(ValueError, match='Wrong number'): + _ = cirq.H(q0, q1) + with pytest.raises(ValueError, match='Wrong number'): + h_op.validate_args([q0, q1]) + + # Passes, skipping validation. + cirq.__cirq_debug__.set(False) + op = cirq.H(q0, q1) + assert op.qubits == (q0, q1) + h_op.validate_args([q0, q1]) + + # Fails again when validation is re-enabled. + cirq.__cirq_debug__.set(True) + with pytest.raises(ValueError, match='Wrong number'): + _ = cirq.H(q0, q1) + with pytest.raises(ValueError, match='Wrong number'): + h_op.validate_args([q0, q1]) + + def test_default_validation_and_inverse(): class TestGate(cirq.Gate): def _num_qubits_(self): @@ -787,6 +811,11 @@ def matrix(self): test_non_qubits = [str(i) for i in range(3)] with pytest.raises(ValueError): _ = g.on_each(*test_non_qubits) + + cirq.__cirq_debug__.set(False) + assert g.on_each(*test_non_qubits)[0].qubits == ('0',) + + cirq.__cirq_debug__.set(True) with pytest.raises(ValueError): _ = g.on_each(*test_non_qubits) @@ -853,6 +882,11 @@ def test_on_each_two_qubits(): g.on_each([(a,)]) with pytest.raises(ValueError, match='Expected 2 qubits'): g.on_each([(a, b, a)]) + + cirq.__cirq_debug__.set(False) + assert g.on_each([(a, b, a)])[0].qubits == (a, b, a) + cirq.__cirq_debug__.set(True) + with pytest.raises(ValueError, match='Expected 2 qubits'): g.on_each(zip([a, a])) with pytest.raises(ValueError, match='Expected 2 qubits'): diff --git a/cirq-core/requirements.txt b/cirq-core/requirements.txt index e26eef68637..eb338eaad0b 100644 --- a/cirq-core/requirements.txt +++ b/cirq-core/requirements.txt @@ -3,6 +3,7 @@ # functools.cached_property was introduced in python 3.8 backports.cached_property~=1.0.1; python_version < '3.8' +contextvars duet~=0.2.7 matplotlib~=3.0 networkx~=2.4 From 8fcfdf86fad19c60c5a90efe817ab616fa8f9e81 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 6 Feb 2023 23:32:28 -0800 Subject: [PATCH 2/5] fix mypy errors --- cirq-core/cirq/ops/raw_types.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 63601487cb8..f3e9f444df7 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -17,6 +17,7 @@ import abc import functools from typing import ( + cast, AbstractSet, Any, Callable, @@ -274,10 +275,13 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation'] if __cirq_debug__.get() is False: return [ - self.on_each(*q) - if isinstance(q, Iterable) and not isinstance(q, str) - else self.on(q) + op for q in targets + for op in ( + self.on_each(*q) + if isinstance(q, Iterable) and not isinstance(q, str) + else [self.on(cast(cirq.Qid, q))] + ) ] for target in targets: From 48b31a232f73717e9f8e772f64d1a637e36aacfe Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 7 Feb 2023 00:06:38 -0800 Subject: [PATCH 3/5] Fix typo --- cirq-core/cirq/ops/raw_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index f3e9f444df7..d0362f99165 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -280,7 +280,7 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation'] for op in ( self.on_each(*q) if isinstance(q, Iterable) and not isinstance(q, str) - else [self.on(cast(cirq.Qid, q))] + else [self.on(cast('cirq.Qid', q))] ) ] From 5a20969d836a10f1b059937bb85a0114d421401d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 7 Feb 2023 20:06:47 -0800 Subject: [PATCH 4/5] Address comments and add a context manager --- cirq-core/cirq/__init__.py | 2 +- cirq-core/cirq/_compat.py | 19 ++++++++++++++++++- cirq-core/cirq/_compat_test.py | 10 ++++++++++ cirq-core/cirq/ops/raw_types.py | 2 +- cirq-core/cirq/ops/raw_types_test.py | 19 ++++++++----------- cirq-core/requirements.txt | 1 - 6 files changed, 38 insertions(+), 15 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 55ff0edf484..b246ccdbf38 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -16,7 +16,7 @@ from cirq import _import -from cirq._compat import __cirq_debug__ +from cirq._compat import __cirq_debug__, with_debug # A module can only depend on modules imported earlier in this list of modules # at import time. Pytest will fail otherwise (enforced by diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index cb79ef5e905..53cc6bb0ea1 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -26,7 +26,7 @@ import traceback import warnings from types import ModuleType -from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar +from typing import Any, Callable, Dict, Iterator, Optional, overload, Set, Tuple, Type, TypeVar import numpy as np import pandas as pd @@ -45,6 +45,23 @@ "The flag is implemented as a `ContextVar` and is thread safe.", ) + +@contextlib.contextmanager +def with_debug(value: bool) -> Iterator[None]: + """Sets the value of global constant `cirq.__cirq_debug__` within the context. + + If `__cirq_debug__` is set to False, all validations in Cirq are disabled to optimize + performance. Users should use the `cirq.with_debug` context manager instead of manually + mutating the value of `__cirq_debug__` flag. On exit, the context manager resets the + value of `__cirq_debug__` flag to what it was before entering the context manager. + """ + token = __cirq_debug__.set(value) + try: + yield + finally: + __cirq_debug__.reset(token) + + try: from functools import cached_property # pylint: disable=unused-import except ImportError: diff --git a/cirq-core/cirq/_compat_test.py b/cirq-core/cirq/_compat_test.py index 7515ec9f2b4..94505535dc8 100644 --- a/cirq-core/cirq/_compat_test.py +++ b/cirq-core/cirq/_compat_test.py @@ -51,6 +51,16 @@ ) +def test_with_debug(): + assert cirq.__cirq_debug__.get() + with cirq.with_debug(False): + assert not cirq.__cirq_debug__.get() + with cirq.with_debug(True): + assert cirq.__cirq_debug__.get() + assert not cirq.__cirq_debug__.get() + assert cirq.__cirq_debug__.get() + + def test_proper_repr(): v = sympy.Symbol('t') * 3 v2 = eval(proper_repr(v)) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index d0362f99165..993cf3cc9ea 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -273,7 +273,7 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation'] operations = [self.on(*target) for target in iterator] return operations - if __cirq_debug__.get() is False: + if not __cirq_debug__.get(): return [ op for q in targets diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index b8b77f5d257..81ab8f61935 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -162,13 +162,12 @@ def test_disable_op_validation(): h_op.validate_args([q0, q1]) # Passes, skipping validation. - cirq.__cirq_debug__.set(False) - op = cirq.H(q0, q1) - assert op.qubits == (q0, q1) - h_op.validate_args([q0, q1]) + with cirq.with_debug(False): + op = cirq.H(q0, q1) + assert op.qubits == (q0, q1) + h_op.validate_args([q0, q1]) # Fails again when validation is re-enabled. - cirq.__cirq_debug__.set(True) with pytest.raises(ValueError, match='Wrong number'): _ = cirq.H(q0, q1) with pytest.raises(ValueError, match='Wrong number'): @@ -812,10 +811,9 @@ def matrix(self): with pytest.raises(ValueError): _ = g.on_each(*test_non_qubits) - cirq.__cirq_debug__.set(False) - assert g.on_each(*test_non_qubits)[0].qubits == ('0',) + with cirq.with_debug(False): + assert g.on_each(*test_non_qubits)[0].qubits == ('0',) - cirq.__cirq_debug__.set(True) with pytest.raises(ValueError): _ = g.on_each(*test_non_qubits) @@ -883,9 +881,8 @@ def test_on_each_two_qubits(): with pytest.raises(ValueError, match='Expected 2 qubits'): g.on_each([(a, b, a)]) - cirq.__cirq_debug__.set(False) - assert g.on_each([(a, b, a)])[0].qubits == (a, b, a) - cirq.__cirq_debug__.set(True) + with cirq.with_debug(False): + assert g.on_each([(a, b, a)])[0].qubits == (a, b, a) with pytest.raises(ValueError, match='Expected 2 qubits'): g.on_each(zip([a, a])) diff --git a/cirq-core/requirements.txt b/cirq-core/requirements.txt index eb338eaad0b..e26eef68637 100644 --- a/cirq-core/requirements.txt +++ b/cirq-core/requirements.txt @@ -3,7 +3,6 @@ # functools.cached_property was introduced in python 3.8 backports.cached_property~=1.0.1; python_version < '3.8' -contextvars duet~=0.2.7 matplotlib~=3.0 networkx~=2.4 From 6fad36a2097e880e55672df6ddd30c40e4ffbb7e Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 15 Feb 2023 16:00:07 -0800 Subject: [PATCH 5/5] Address nit --- cirq-core/cirq/_compat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 53cc6bb0ea1..a39833abbfe 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -14,9 +14,8 @@ """Workarounds for compatibility issues between versions and libraries.""" import contextlib -import dataclasses - import contextvars +import dataclasses import functools import importlib import inspect