Skip to content

cirq-core target gatesets: accept additional gates to keep untouched. #5445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,43 @@
"cirq_type": "CZTargetGateset",
"atol": 1e-08,
"allow_partial_czs": true
},
{
"cirq_type": "CZTargetGateset",
"atol": 1e-06,
"allow_partial_czs": true,
"additional_gates": [
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "ISwapPowGate",
"exponent": 0.5,
"global_shift": 0.0
},
"name": "Instance GateFamily: ISWAP**0.5",
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "ZPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
}
]
}
]
18 changes: 16 additions & 2 deletions cirq-core/cirq/protocols/json_test_data/CZTargetGateset.repr
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
[
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False),
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True),
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False, additional_gates=[]),
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True, additional_gates=[]),
cirq.CZTargetGateset(
atol=1e-06,
allow_partial_czs=True,
additional_gates=[
(cirq.ISWAP**0.5),
cirq.ops.common_gates.XPowGate,
cirq.GateFamily(
gate=cirq.ops.common_gates.ZPowGate,
ignore_global_phase=True,
tags_to_accept=frozenset(),
tags_to_ignore=frozenset(),
),
],
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,44 @@
"atol": 1e-06,
"required_sqrt_iswap_count": 2,
"use_sqrt_iswap_inv": true
},
{
"cirq_type": "SqrtIswapTargetGateset",
"atol": 1e-08,
"required_sqrt_iswap_count": null,
"use_sqrt_iswap_inv": false,
"additional_gates": [
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "CZPowGate",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "Instance GateFamily: CZ",
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "ZPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
}
]
}
]
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
[
cirq.SqrtIswapTargetGateset(
atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False
atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False, additional_gates=[]
),
cirq.SqrtIswapTargetGateset(
atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False, additional_gates=[]
),
cirq.SqrtIswapTargetGateset(
atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True, additional_gates=[]
),
cirq.SqrtIswapTargetGateset(
atol=1e-08,
required_sqrt_iswap_count=None,
use_sqrt_iswap_inv=False,
additional_gates=[
cirq.CZ,
cirq.ops.common_gates.XPowGate,
cirq.GateFamily(
gate=cirq.ops.common_gates.ZPowGate,
ignore_global_phase=True,
tags_to_accept=frozenset(),
tags_to_ignore=frozenset(),
),
],
),
cirq.SqrtIswapTargetGateset(atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False),
cirq.SqrtIswapTargetGateset(atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True),
]
59 changes: 50 additions & 9 deletions cirq-core/cirq/transformers/target_gatesets/cz_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Target gateset used for compiling circuits to CZ + 1-q rotations + measurement gates."""

from typing import Any, Dict, TYPE_CHECKING
from typing import Any, Dict, Sequence, Type, Union, TYPE_CHECKING

from cirq import ops, protocols
from cirq.transformers.analytical_decompositions import two_qubit_to_cz
Expand All @@ -25,23 +25,53 @@


class CZTargetGateset(compilation_target_gateset.TwoQubitCompilationTargetGateset):
"""Target gateset containing CZ + single qubit rotations + Measurement gates."""
"""Target gateset accepting CZ + single qubit rotations + measurement gates.

def __init__(self, *, atol: float = 1e-8, allow_partial_czs: bool = False) -> None:
By default, `cirq.CZTargetGateset` will accept and compile unknown gates to
the following universal target gateset:
- `cirq.CZ` / `cirq.CZPowGate`: The two qubit entangling gate.
- `cirq.PhasedXZGate`: Single qubit rotations.
- `cirq.MeasurementGate`: Measurements.
- `cirq.GlobalPhaseGate`: Global phase.

Optionally, users can also specify additional gates / gate families which should
be accepted by this gateset via the `additional_gates` argument.

When compiling a circuit, any unknown gate, i.e. a gate which is not accepted by
this gateset, will be compiled to the default gateset (i.e. `cirq.CZ`/`cirq.CZPowGate`,
`cirq.PhasedXZGate`, `cirq.MeasurementGate`).
"""

def __init__(
self,
*,
atol: float = 1e-8,
allow_partial_czs: bool = False,
additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (),
) -> None:
"""Initializes CZTargetGateset

Args:
atol: A limit on the amount of absolute error introduced by the decomposition.
allow_partial_czs: If set, all powers of the form `cirq.CZ**t`, and not just
`cirq.CZ`, are part of this gateset.
additional_gates: Sequence of additional gates / gate families which should also
be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`.
"""
super().__init__(
ops.CZPowGate if allow_partial_czs else ops.CZ,
ops.MeasurementGate,
ops.AnyUnitaryGateFamily(1),
ops.PhasedXZGate,
ops.GlobalPhaseGate,
*additional_gates,
name='CZPowTargetGateset' if allow_partial_czs else 'CZTargetGateset',
)
self.additional_gates = tuple(
g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates
)
self._additional_gates_repr_str = ", ".join(
[ops.gateset._gate_str(g, repr) for g in additional_gates]
)
self.atol = atol
self.allow_partial_czs = allow_partial_czs

Expand All @@ -57,14 +87,25 @@ def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> 'cirq.OP_TR
)

def __repr__(self) -> str:
return f'cirq.CZTargetGateset(atol={self.atol}, allow_partial_czs={self.allow_partial_czs})'
return (
f'cirq.CZTargetGateset('
f'atol={self.atol}, '
f'allow_partial_czs={self.allow_partial_czs}, '
f'additional_gates=[{self._additional_gates_repr_str}]'
f')'
)

def _value_equality_values_(self) -> Any:
return self.atol, self.allow_partial_czs
return self.atol, self.allow_partial_czs, frozenset(self.additional_gates)

def _json_dict_(self) -> Dict[str, Any]:
return {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
d: Dict[str, Any] = {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
if self.additional_gates:
d['additional_gates'] = list(self.additional_gates)
return d

@classmethod
def _from_json_dict_(cls, atol, allow_partial_czs, **kwargs):
return cls(atol=atol, allow_partial_czs=allow_partial_czs)
def _from_json_dict_(cls, atol, allow_partial_czs, additional_gates=(), **kwargs):
return cls(
atol=atol, allow_partial_czs=allow_partial_czs, additional_gates=additional_gates
)
64 changes: 52 additions & 12 deletions cirq-core/cirq/transformers/target_gatesets/cz_gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Type
import pytest
import cirq
import sympy
Expand All @@ -25,9 +26,18 @@ def all_gates_of_type(m: cirq.Moment, g: cirq.Gateset):
return True


def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
def assert_optimizes(
before: cirq.Circuit,
expected: cirq.Circuit,
additional_gates: Optional[Sequence[Type[cirq.Gate]]] = None,
):
if additional_gates is None:
gateset = cirq.CZTargetGateset()
else:
gateset = cirq.CZTargetGateset(additional_gates=additional_gates)

cirq.testing.assert_same_circuits(
cirq.optimize_for_target_gateset(before, gateset=cirq.CZTargetGateset()), expected
cirq.optimize_for_target_gateset(before, gateset=gateset, ignore_failures=False), expected
)


Expand All @@ -37,7 +47,7 @@ def assert_optimization_not_broken(circuit: cirq.Circuit):
circuit, c_new, atol=1e-6
)
c_new = cirq.optimize_for_target_gateset(
circuit, gateset=cirq.CZTargetGateset(allow_partial_czs=True)
circuit, gateset=cirq.CZTargetGateset(allow_partial_czs=True), ignore_failures=False
)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
circuit, c_new, atol=1e-6
Expand All @@ -57,15 +67,19 @@ def test_convert_to_cz_preserving_moment_structure():
cirq.X(q[2]).with_classical_controls("m"),
cirq.CZ(*q[3:]).with_classical_controls("m"),
)
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=cirq.CZTargetGateset())
# Classically controlled operations are not part of the gateset, so failures should be ignored
# during compilation.
c_new = cirq.optimize_for_target_gateset(
c_orig, gateset=cirq.CZTargetGateset(), ignore_failures=True
)

assert c_orig[-2:] == c_new[-2:]
c_orig, c_new = c_orig[:-2], c_new[:-2]

cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6)
assert all(
(
all_gates_of_type(m, cirq.Gateset(cirq.AnyUnitaryGateFamily(1)))
all_gates_of_type(m, cirq.Gateset(cirq.PhasedXZGate))
or all_gates_of_type(m, cirq.Gateset(cirq.CZ))
)
for m in c_new
Expand All @@ -77,7 +91,7 @@ def test_convert_to_cz_preserving_moment_structure():
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6)
assert all(
(
all_gates_of_type(m, cirq.Gateset(cirq.AnyUnitaryGateFamily(1)))
all_gates_of_type(m, cirq.Gateset(cirq.PhasedXZGate))
or all_gates_of_type(m, cirq.Gateset(cirq.CZPowGate))
)
for m in c_new
Expand Down Expand Up @@ -109,6 +123,7 @@ def test_ignores_czs_separated_by_parameterized():
cirq.Moment(cirq.CZ(a, b)),
]
),
additional_gates=[cirq.ZPowGate],
)


Expand Down Expand Up @@ -153,15 +168,15 @@ def test_optimizes_single_iswap():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.ISWAP(a, b))
assert_optimization_not_broken(c)
c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset())
c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset(), ignore_failures=False)
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2


def test_optimizes_tagged_partial_cz():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit((cirq.CZ**0.5)(a, b).with_tags('mytag'))
assert_optimization_not_broken(c)
c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset())
c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset(), ignore_failures=False)
assert (
len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2
), 'It should take 2 CZ gates to decompose a CZ**0.5 gate'
Expand All @@ -185,7 +200,9 @@ def test_not_decompose_czs():
),
)
def test_decompose_partial_czs(circuit):
circuit = cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset())
circuit = cirq.optimize_for_target_gateset(
circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False
)
cz_gates = [
op.gate
for op in circuit.all_operations()
Expand All @@ -201,7 +218,7 @@ def test_not_decompose_partial_czs():
circuit = cirq.Circuit(
cirq.CZPowGate(exponent=0.1, global_shift=-0.5)(*cirq.LineQubit.range(2))
)
cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset())
cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False)
cz_gates = [
op.gate
for op in circuit.all_operations()
Expand Down Expand Up @@ -240,7 +257,7 @@ def _decompose_(self, qubits):

a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(OtherXX()(a, b), OtherOtherXX()(a, b))
c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset())
c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset(), ignore_failures=False)
assert len(c) == 0


Expand All @@ -260,7 +277,9 @@ def _decompose_(self, qubits):
expected = cirq.Circuit(
cirq.X(q0), cirq.Y(q0) ** 0.5, cirq.CZ(q0, q1), cirq.X(q1), cirq.Y(q1) ** 0.5
)
c_new = cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset())
c_new = cirq.optimize_for_target_gateset(
circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False
)

cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
c_new, expected, atol=1e-6
Expand All @@ -281,3 +300,24 @@ class UnsupportedDummy(cirq.testing.TwoQubitGate):
_ = cirq.optimize_for_target_gateset(
circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False
)


@pytest.mark.parametrize(
'gateset',
[
cirq.CZTargetGateset(),
cirq.CZTargetGateset(
atol=1e-6,
allow_partial_czs=True,
additional_gates=[
cirq.SQRT_ISWAP,
cirq.XPowGate,
cirq.YPowGate,
cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['test_tag']),
],
),
cirq.CZTargetGateset(additional_gates=()),
],
)
def test_repr(gateset):
cirq.testing.assert_equivalent_repr(gateset)
Loading