Skip to content

Commit a46cdf0

Browse files
verultrht
authored andcommitted
cirq-core target gatesets: accept additional gates to keep untouched. (quantumlib#5445)
Builds on top of quantumlib#5429 The internal gate representation for `additional_gates` is updated to match `cirq.Gateset`: * Equality check uses GateFamily representation. Otherwise different representations of the gate will not be considered equal. * JSON uses GateFamily representation. * repr uses the representation passed in via the constructor. `assert_optimizes` in `cz_gateset_test.py` is updated to take in an optional `additional_gates` instead, to exercise CZTargetGateset constructor's defaulting logic. No tests are added since `additional_gates` need to be set in existing tests after `ignore_errors` is set to False. @tanujkhattar
1 parent cd2bfa0 commit a46cdf0

File tree

9 files changed

+314
-53
lines changed

9 files changed

+314
-53
lines changed

cirq-core/cirq/contrib/paulistring/optimize_test.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,15 @@ def test_optimize():
5050

5151
cirq.testing.assert_allclose_up_to_global_phase(c_orig.unitary(), c_opt.unitary(), atol=1e-6)
5252

53+
# TODO(#5546) Fix '[Z]^1' (should be 'Z')
5354
cirq.testing.assert_has_diagram(
5455
c_opt,
5556
"""
56-
0: ───X^0.5────────────@────────────────────────────────────────
57+
0: ───X^0.5────────────@──────────────────────────────────────────────
5758
58-
1: ───@───────X^-0.5───@───@────────────────@───Z^-0.5──────────
59+
1: ───@───────X^-0.5───@───@────────────────@───Z^-0.5────────────────
5960
│ │ │
60-
2: ───@────────────────────@───[X]^(-7/8)───@───[X]^-0.25───Z───
61+
2: ───@────────────────────@───[X]^(-7/8)───@───[X]^-0.25───[Z]^(1)───
6162
""",
6263
)
6364

cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json

+38
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,43 @@
88
"cirq_type": "CZTargetGateset",
99
"atol": 1e-08,
1010
"allow_partial_czs": true
11+
},
12+
{
13+
"cirq_type": "CZTargetGateset",
14+
"atol": 1e-06,
15+
"allow_partial_czs": true,
16+
"additional_gates": [
17+
{
18+
"cirq_type": "GateFamily",
19+
"gate": {
20+
"cirq_type": "ISwapPowGate",
21+
"exponent": 0.5,
22+
"global_shift": 0.0
23+
},
24+
"name": "Instance GateFamily: ISWAP**0.5",
25+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
26+
"ignore_global_phase": true,
27+
"tags_to_accept": [],
28+
"tags_to_ignore": []
29+
},
30+
{
31+
"cirq_type": "GateFamily",
32+
"gate": "XPowGate",
33+
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
34+
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
35+
"ignore_global_phase": true,
36+
"tags_to_accept": [],
37+
"tags_to_ignore": []
38+
},
39+
{
40+
"cirq_type": "GateFamily",
41+
"gate": "ZPowGate",
42+
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
43+
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
44+
"ignore_global_phase": true,
45+
"tags_to_accept": [],
46+
"tags_to_ignore": []
47+
}
48+
]
1149
}
1250
]
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
11
[
2-
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False),
3-
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True),
2+
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False, additional_gates=[]),
3+
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True, additional_gates=[]),
4+
cirq.CZTargetGateset(
5+
atol=1e-06,
6+
allow_partial_czs=True,
7+
additional_gates=[
8+
(cirq.ISWAP**0.5),
9+
cirq.ops.common_gates.XPowGate,
10+
cirq.GateFamily(
11+
gate=cirq.ops.common_gates.ZPowGate,
12+
ignore_global_phase=True,
13+
tags_to_accept=frozenset(),
14+
tags_to_ignore=frozenset(),
15+
),
16+
],
17+
),
418
]

cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.json

+39
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,44 @@
1616
"atol": 1e-06,
1717
"required_sqrt_iswap_count": 2,
1818
"use_sqrt_iswap_inv": true
19+
},
20+
{
21+
"cirq_type": "SqrtIswapTargetGateset",
22+
"atol": 1e-08,
23+
"required_sqrt_iswap_count": null,
24+
"use_sqrt_iswap_inv": false,
25+
"additional_gates": [
26+
{
27+
"cirq_type": "GateFamily",
28+
"gate": {
29+
"cirq_type": "CZPowGate",
30+
"exponent": 1.0,
31+
"global_shift": 0.0
32+
},
33+
"name": "Instance GateFamily: CZ",
34+
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
35+
"ignore_global_phase": true,
36+
"tags_to_accept": [],
37+
"tags_to_ignore": []
38+
},
39+
{
40+
"cirq_type": "GateFamily",
41+
"gate": "XPowGate",
42+
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
43+
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
44+
"ignore_global_phase": true,
45+
"tags_to_accept": [],
46+
"tags_to_ignore": []
47+
},
48+
{
49+
"cirq_type": "GateFamily",
50+
"gate": "ZPowGate",
51+
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
52+
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
53+
"ignore_global_phase": true,
54+
"tags_to_accept": [],
55+
"tags_to_ignore": []
56+
}
57+
]
1958
}
2059
]
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,26 @@
11
[
22
cirq.SqrtIswapTargetGateset(
3-
atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False
3+
atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False, additional_gates=[]
4+
),
5+
cirq.SqrtIswapTargetGateset(
6+
atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False, additional_gates=[]
7+
),
8+
cirq.SqrtIswapTargetGateset(
9+
atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True, additional_gates=[]
10+
),
11+
cirq.SqrtIswapTargetGateset(
12+
atol=1e-08,
13+
required_sqrt_iswap_count=None,
14+
use_sqrt_iswap_inv=False,
15+
additional_gates=[
16+
cirq.CZ,
17+
cirq.ops.common_gates.XPowGate,
18+
cirq.GateFamily(
19+
gate=cirq.ops.common_gates.ZPowGate,
20+
ignore_global_phase=True,
21+
tags_to_accept=frozenset(),
22+
tags_to_ignore=frozenset(),
23+
),
24+
],
425
),
5-
cirq.SqrtIswapTargetGateset(atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False),
6-
cirq.SqrtIswapTargetGateset(atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True),
726
]

cirq-core/cirq/transformers/target_gatesets/cz_gateset.py

+50-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Target gateset used for compiling circuits to CZ + 1-q rotations + measurement gates."""
1616

17-
from typing import Any, Dict, TYPE_CHECKING
17+
from typing import Any, Dict, Sequence, Type, Union, TYPE_CHECKING
1818

1919
from cirq import ops, protocols
2020
from cirq.transformers.analytical_decompositions import two_qubit_to_cz
@@ -25,23 +25,53 @@
2525

2626

2727
class CZTargetGateset(compilation_target_gateset.TwoQubitCompilationTargetGateset):
28-
"""Target gateset containing CZ + single qubit rotations + Measurement gates."""
28+
"""Target gateset accepting CZ + single qubit rotations + measurement gates.
2929
30-
def __init__(self, *, atol: float = 1e-8, allow_partial_czs: bool = False) -> None:
30+
By default, `cirq.CZTargetGateset` will accept and compile unknown gates to
31+
the following universal target gateset:
32+
- `cirq.CZ` / `cirq.CZPowGate`: The two qubit entangling gate.
33+
- `cirq.PhasedXZGate`: Single qubit rotations.
34+
- `cirq.MeasurementGate`: Measurements.
35+
- `cirq.GlobalPhaseGate`: Global phase.
36+
37+
Optionally, users can also specify additional gates / gate families which should
38+
be accepted by this gateset via the `additional_gates` argument.
39+
40+
When compiling a circuit, any unknown gate, i.e. a gate which is not accepted by
41+
this gateset, will be compiled to the default gateset (i.e. `cirq.CZ`/`cirq.CZPowGate`,
42+
`cirq.PhasedXZGate`, `cirq.MeasurementGate`).
43+
"""
44+
45+
def __init__(
46+
self,
47+
*,
48+
atol: float = 1e-8,
49+
allow_partial_czs: bool = False,
50+
additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (),
51+
) -> None:
3152
"""Initializes CZTargetGateset
3253
3354
Args:
3455
atol: A limit on the amount of absolute error introduced by the decomposition.
3556
allow_partial_czs: If set, all powers of the form `cirq.CZ**t`, and not just
3657
`cirq.CZ`, are part of this gateset.
58+
additional_gates: Sequence of additional gates / gate families which should also
59+
be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`.
3760
"""
3861
super().__init__(
3962
ops.CZPowGate if allow_partial_czs else ops.CZ,
4063
ops.MeasurementGate,
41-
ops.AnyUnitaryGateFamily(1),
64+
ops.PhasedXZGate,
4265
ops.GlobalPhaseGate,
66+
*additional_gates,
4367
name='CZPowTargetGateset' if allow_partial_czs else 'CZTargetGateset',
4468
)
69+
self.additional_gates = tuple(
70+
g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates
71+
)
72+
self._additional_gates_repr_str = ", ".join(
73+
[ops.gateset._gate_str(g, repr) for g in additional_gates]
74+
)
4575
self.atol = atol
4676
self.allow_partial_czs = allow_partial_czs
4777

@@ -57,14 +87,25 @@ def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> 'cirq.OP_TR
5787
)
5888

5989
def __repr__(self) -> str:
60-
return f'cirq.CZTargetGateset(atol={self.atol}, allow_partial_czs={self.allow_partial_czs})'
90+
return (
91+
f'cirq.CZTargetGateset('
92+
f'atol={self.atol}, '
93+
f'allow_partial_czs={self.allow_partial_czs}, '
94+
f'additional_gates=[{self._additional_gates_repr_str}]'
95+
f')'
96+
)
6197

6298
def _value_equality_values_(self) -> Any:
63-
return self.atol, self.allow_partial_czs
99+
return self.atol, self.allow_partial_czs, frozenset(self.additional_gates)
64100

65101
def _json_dict_(self) -> Dict[str, Any]:
66-
return {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
102+
d: Dict[str, Any] = {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
103+
if self.additional_gates:
104+
d['additional_gates'] = list(self.additional_gates)
105+
return d
67106

68107
@classmethod
69-
def _from_json_dict_(cls, atol, allow_partial_czs, **kwargs):
70-
return cls(atol=atol, allow_partial_czs=allow_partial_czs)
108+
def _from_json_dict_(cls, atol, allow_partial_czs, additional_gates=(), **kwargs):
109+
return cls(
110+
atol=atol, allow_partial_czs=allow_partial_czs, additional_gates=additional_gates
111+
)

0 commit comments

Comments
 (0)