Skip to content

Add __cirq_debug__ flag and conditionally disable qid validations in gates and operations #6000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"environment_type": "virtualenv",
"show_commit_url": "https://github.com/quantumlib/Cirq/commit/",
"pythons": ["3.8"],
"matrix": {"env_nobuild": {"PYTHONOPTIMIZE": ["-O", ""]}},
"benchmark_dir": "benchmarks",
"env_dir": ".asv/env",
"results_dir": ".asv/results",
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from cirq import _import

from cirq._compat import __cirq_debug__, with_debug

# A module can only depend on modules imported earlier in this list of modules
# at import time. Pytest will fail otherwise (enforced by
# dev_tools/import_test.py).
Expand Down
29 changes: 28 additions & 1 deletion cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Workarounds for compatibility issues between versions and libraries."""
import contextlib
import contextvars
import dataclasses
import functools
import importlib
Expand All @@ -24,15 +25,41 @@
import traceback
import warnings
from types import ModuleType
from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar
from typing import Any, Callable, Dict, Iterator, Optional, overload, Set, Tuple, Type, TypeVar

import numpy as np
import pandas as pd
import sympy
import sympy.printing.repr

from cirq._doc import document

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'

__cirq_debug__ = contextvars.ContextVar('__cirq_debug__', default=__debug__)
document(
__cirq_debug__,
"A cirq specific flag which can be used to conditionally turn off all validations across Cirq "
"to boost performance in production mode. Defaults to python's built-in constant __debug__. "
"The flag is implemented as a `ContextVar` and is thread safe.",
)


@contextlib.contextmanager
def with_debug(value: bool) -> Iterator[None]:
"""Sets the value of global constant `cirq.__cirq_debug__` within the context.

If `__cirq_debug__` is set to False, all validations in Cirq are disabled to optimize
performance. Users should use the `cirq.with_debug` context manager instead of manually
mutating the value of `__cirq_debug__` flag. On exit, the context manager resets the
value of `__cirq_debug__` flag to what it was before entering the context manager.
"""
token = __cirq_debug__.set(value)
try:
yield
finally:
__cirq_debug__.reset(token)


try:
from functools import cached_property # pylint: disable=unused-import
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@
)


def test_with_debug():
assert cirq.__cirq_debug__.get()
with cirq.with_debug(False):
assert not cirq.__cirq_debug__.get()
with cirq.with_debug(True):
assert cirq.__cirq_debug__.get()
assert not cirq.__cirq_debug__.get()
assert cirq.__cirq_debug__.get()


def test_proper_repr():
v = sympy.Symbol('t') * 3
v2 = eval(proper_repr(v))
Expand Down
44 changes: 31 additions & 13 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
import functools
from typing import (
cast,
AbstractSet,
Any,
Callable,
Expand All @@ -40,6 +41,7 @@

from cirq import protocols, value
from cirq._import import LazyLoader
from cirq._compat import __cirq_debug__
from cirq.type_workarounds import NotImplementedType
from cirq.ops import control_values as cv

Expand Down Expand Up @@ -215,7 +217,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None:
Raises:
ValueError: The gate can't be applied to the qubits.
"""
_validate_qid_shape(self, qubits)
if __cirq_debug__.get():
_validate_qid_shape(self, qubits)

def on(self, *qubits: Qid) -> 'Operation':
"""Returns an application of this gate to the given qubits.
Expand Down Expand Up @@ -254,19 +257,33 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation']
raise TypeError(f'{targets[0]} object is not iterable.')
t0 = list(targets[0])
iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
if __cirq_debug__.get():
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
else:
operations = [self.on(*target) for target in iterator]
return operations

if not __cirq_debug__.get():
return [
op
for q in targets
for op in (
self.on_each(*q)
if isinstance(q, Iterable) and not isinstance(q, str)
else [self.on(cast('cirq.Qid', q))]
)
]

for target in targets:
if isinstance(target, Qid):
operations.append(self.on(target))
Expand Down Expand Up @@ -617,7 +634,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
Raises:
ValueError: The operation had qids that don't match it's qid shape.
"""
_validate_qid_shape(self, qubits)
if __cirq_debug__.get():
_validate_qid_shape(self, qubits)

def _commutes_(
self, other: Any, *, atol: float = 1e-8
Expand Down
31 changes: 31 additions & 0 deletions cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,29 @@ def test_op_validate():
op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)])


def test_disable_op_validation():
q0, q1 = cirq.LineQubit.range(2)
h_op = cirq.H(q0)

# Fails normally.
with pytest.raises(ValueError, match='Wrong number'):
_ = cirq.H(q0, q1)
with pytest.raises(ValueError, match='Wrong number'):
h_op.validate_args([q0, q1])

# Passes, skipping validation.
with cirq.with_debug(False):
op = cirq.H(q0, q1)
assert op.qubits == (q0, q1)
h_op.validate_args([q0, q1])

# Fails again when validation is re-enabled.
with pytest.raises(ValueError, match='Wrong number'):
_ = cirq.H(q0, q1)
with pytest.raises(ValueError, match='Wrong number'):
h_op.validate_args([q0, q1])


def test_default_validation_and_inverse():
class TestGate(cirq.Gate):
def _num_qubits_(self):
Expand Down Expand Up @@ -787,6 +810,10 @@ def matrix(self):
test_non_qubits = [str(i) for i in range(3)]
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)

with cirq.with_debug(False):
assert g.on_each(*test_non_qubits)[0].qubits == ('0',)

with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)

Expand Down Expand Up @@ -853,6 +880,10 @@ def test_on_each_two_qubits():
g.on_each([(a,)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each([(a, b, a)])

with cirq.with_debug(False):
assert g.on_each([(a, b, a)])[0].qubits == (a, b, a)

with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each(zip([a, a]))
with pytest.raises(ValueError, match='Expected 2 qubits'):
Expand Down