Skip to content

Commit e0ae7ca

Browse files
authored
Use GoogleCZTargetGateset; add device gateset as additional gates in target gatesets (#5765)
* GridDevice target gateset generation logic now matches all required gates for a particular gateset, rather than just the 2q gate. * In target gatesets that support additional_gates, it's set to all gates in a device's gateset other than required gates in the target gateset, so that device gates are not decomposed during circuit transformation. First commit is from #5744 @tanujkhattar cc @dstrain115
1 parent 27b5d4b commit e0ae7ca

File tree

5 files changed

+149
-59
lines changed

5 files changed

+149
-59
lines changed

cirq-core/cirq/transformers/target_gatesets/cz_gateset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
allow_partial_czs: If set, all powers of the form `cirq.CZ**t`, and not just
5757
`cirq.CZ`, are part of this gateset.
5858
additional_gates: Sequence of additional gates / gate families which should also
59-
be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`.
59+
be "accepted" by this gateset. This is empty by default.
6060
"""
6161
super().__init__(
6262
ops.CZPowGate if allow_partial_czs else ops.CZ,

cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
use_sqrt_iswap_inv: If True, `cirq.SQRT_ISWAP_INV` is used as part of the gateset,
6363
instead of `cirq.SQRT_ISWAP`.
6464
additional_gates: Sequence of additional gates / gate families which should also
65-
be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`.
65+
be "accepted" by this gateset. This is empty by default.
6666
6767
Raises:
6868
ValueError: If `required_sqrt_iswap_count` is specified and is not 0, 1, 2, or 3.

cirq-google/cirq_google/devices/grid_device.py

+87-51
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,39 @@
3939
from cirq_google.experimental import ops as experimental_ops
4040

4141

42-
SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC)
43-
SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP)
44-
SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV)
45-
CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ)
46-
PHASED_XZ_GATE_FAMILY = cirq.GateFamily(cirq.PhasedXZGate)
47-
VIRTUAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])
48-
PHYSICAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])
49-
COUPLER_PULSE_GATE_FAMILY = cirq.GateFamily(experimental_ops.CouplerPulse)
50-
MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate)
51-
WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate)
42+
# Gate family constants used in various parts of GridDevice logic.
43+
_SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC)
44+
_SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP)
45+
_SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV)
46+
_CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ)
47+
_PHASED_XZ_GATE_FAMILY = cirq.GateFamily(cirq.PhasedXZGate)
48+
_VIRTUAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])
49+
_PHYSICAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])
50+
_COUPLER_PULSE_GATE_FAMILY = cirq.GateFamily(experimental_ops.CouplerPulse)
51+
_MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate)
52+
_WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate)
53+
54+
_SYC_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[ops.SYC])
55+
_SQRT_ISWAP_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])
56+
_SQRT_ISWAP_INV_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])
57+
_CZ_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.CZ])
58+
59+
60+
# TODO(#5050) Add GlobalPhaseGate
61+
# Target gates of `cirq_google.GoogleCZTargetGateset`.
62+
_CZ_TARGET_GATES = [_CZ_FSIM_GATE_FAMILY, _PHASED_XZ_GATE_FAMILY, _MEASUREMENT_GATE_FAMILY]
63+
# Target gates of `cirq_google.SycamoreTargetGateset`.
64+
_SYC_TARGET_GATES = [_SYC_FSIM_GATE_FAMILY, _PHASED_XZ_GATE_FAMILY, _MEASUREMENT_GATE_FAMILY]
65+
# Target gates of `cirq.SqrtIswapTargetGateset`
66+
_SQRT_ISWAP_TARGET_GATES = [
67+
_SQRT_ISWAP_FSIM_GATE_FAMILY,
68+
_PHASED_XZ_GATE_FAMILY,
69+
_MEASUREMENT_GATE_FAMILY,
70+
]
71+
5272

5373
# Families of gates which can be applied to any subset of valid qubits.
54-
_VARIADIC_GATE_FAMILIES = [MEASUREMENT_GATE_FAMILY, WAIT_GATE_FAMILY]
74+
_VARIADIC_GATE_FAMILIES = [_MEASUREMENT_GATE_FAMILY, _WAIT_GATE_FAMILY]
5575

5676

5777
def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None:
@@ -115,19 +135,19 @@ def _build_gateset_and_gate_durations(
115135
cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []
116136

117137
if gate_name == 'syc':
118-
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[ops.SYC])]
138+
cirq_gates = [_SYC_FSIM_GATE_FAMILY]
119139
elif gate_name == 'sqrt_iswap':
120-
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])]
140+
cirq_gates = [_SQRT_ISWAP_FSIM_GATE_FAMILY]
121141
elif gate_name == 'sqrt_iswap_inv':
122-
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])]
142+
cirq_gates = [_SQRT_ISWAP_INV_FSIM_GATE_FAMILY]
123143
elif gate_name == 'cz':
124-
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.CZ])]
144+
cirq_gates = [_CZ_FSIM_GATE_FAMILY]
125145
elif gate_name == 'phased_xz':
126146
cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate]
127147
elif gate_name == 'virtual_zpow':
128-
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])]
148+
cirq_gates = [_VIRTUAL_ZPOW_GATE_FAMILY]
129149
elif gate_name == 'physical_zpow':
130-
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])]
150+
cirq_gates = [_PHYSICAL_ZPOW_GATE_FAMILY]
131151
elif gate_name == 'coupler_pulse':
132152
cirq_gates = [experimental_ops.CouplerPulse]
133153
elif gate_name == 'meas':
@@ -161,28 +181,27 @@ def _build_gateset_and_gate_durations(
161181
def _build_compilation_target_gatesets(
162182
gateset: cirq.Gateset,
163183
) -> Sequence[cirq.CompilationTargetGateset]:
164-
"""Detects compilation target gatesets based on what gates are inside the gateset.
165-
166-
If a device contains gates which yield multiple compilation target gatesets, the user can only
167-
choose one target gateset to compile to. For example, a device may contain both SYC and
168-
SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to
169-
either SYC or SQRT_ISWAP for its two-qubit gates, not both.
170-
171-
TODO(#5050) when cirq-google CompilationTargetGateset subclasses are implemented, mention that
172-
gates which are part of the gateset but not the compilation target gateset are untouched when
173-
compiled.
174-
"""
175-
176-
# TODO(#5050) Subclass core CompilationTargetGatesets in cirq-google.
184+
"""Detects compilation target gatesets based on what gates are inside the gateset."""
177185

186+
# Include a particular target gateset if the device's gateset contains all required gates of
187+
# the target gateset.
188+
# Set all remaining gates in the device's gateset as `additional_gates` so that they are not
189+
# decomposed in the transformation process.
178190
target_gatesets: List[cirq.CompilationTargetGateset] = []
179-
if cirq.CZ in gateset:
180-
target_gatesets.append(cirq.CZTargetGateset())
181-
if ops.SYC in gateset:
191+
if all(gate_family in gateset.gates for gate_family in _CZ_TARGET_GATES):
192+
target_gatesets.append(
193+
transformers.GoogleCZTargetGateset(
194+
additional_gates=list(gateset.gates - set(_CZ_TARGET_GATES))
195+
)
196+
)
197+
if all(gate_family in gateset.gates for gate_family in _SYC_TARGET_GATES):
198+
# TODO(#5050) SycamoreTargetGateset additional gates
182199
target_gatesets.append(transformers.SycamoreTargetGateset())
183-
if cirq.SQRT_ISWAP in gateset:
200+
if all(gate_family in gateset.gates for gate_family in _SQRT_ISWAP_TARGET_GATES):
184201
target_gatesets.append(
185-
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset)
202+
cirq.SqrtIswapTargetGateset(
203+
additional_gates=list(gateset.gates - set(_SQRT_ISWAP_TARGET_GATES))
204+
)
186205
)
187206

188207
return tuple(target_gatesets)
@@ -265,7 +284,7 @@ class GridDevice(cirq.Device):
265284
transform a circuit to one which only contains gates from a native target gateset
266285
supported by the device.
267286
>>> device.metadata.compilation_target_gatesets
268-
(...cirq.CZTargetGateset...)
287+
(...cirq_google.GoogleCZTargetGateset...)
269288
270289
* Assuming valid CompilationTargetGatesets exist for the device, select the first one and
271290
use it to transform a circuit to one which only contains gates from a native target
@@ -277,11 +296,18 @@ class GridDevice(cirq.Device):
277296
>>> print(circuit)
278297
(5, 1): ───PhXZ(a=0,x=1,z=0)───
279298
280-
A note about CompilationTargetGatesets:
299+
Notes about CompilationTargetGatesets:
281300
282-
A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using
283-
CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert
284-
WaitGates after the circuit has been transformed.
301+
* If a device contains gates which yield multiple compilation target gatesets, the user can only
302+
choose one target gateset to compile to. For example, a device may contain both SYC and
303+
SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled
304+
to either SYC or SQRT_ISWAP for its two-qubit gates, not both.
305+
* For a given compilation target gateset, gates which are part of the device's gateset but not
306+
the target gateset are not decomposed. However, they may still be merged with other gates in
307+
the circuit.
308+
* A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using
309+
CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert
310+
WaitGates after the circuit has been transformed.
285311
286312
Notes for cirq_google internal implementation:
287313
@@ -435,31 +461,34 @@ def _value_equality_values_(self):
435461
def _set_gate_in_gate_spec(
436462
gate_spec: v2.device_pb2.GateSpecification, gate_family: cirq.GateFamily
437463
) -> None:
438-
if gate_family == SYC_GATE_FAMILY:
464+
if gate_family == _SYC_GATE_FAMILY or gate_family == _SYC_FSIM_GATE_FAMILY:
439465
gate_spec.syc.SetInParent()
440-
elif gate_family == SQRT_ISWAP_GATE_FAMILY:
466+
elif gate_family == _SQRT_ISWAP_GATE_FAMILY or gate_family == _SQRT_ISWAP_FSIM_GATE_FAMILY:
441467
gate_spec.sqrt_iswap.SetInParent()
442-
elif gate_family == SQRT_ISWAP_INV_GATE_FAMILY:
468+
elif (
469+
gate_family == _SQRT_ISWAP_INV_GATE_FAMILY
470+
or gate_family == _SQRT_ISWAP_INV_FSIM_GATE_FAMILY
471+
):
443472
gate_spec.sqrt_iswap_inv.SetInParent()
444-
elif gate_family == CZ_GATE_FAMILY:
473+
elif gate_family == _CZ_GATE_FAMILY or gate_family == _CZ_FSIM_GATE_FAMILY:
445474
gate_spec.cz.SetInParent()
446-
elif gate_family == PHASED_XZ_GATE_FAMILY:
475+
elif gate_family == _PHASED_XZ_GATE_FAMILY:
447476
gate_spec.phased_xz.SetInParent()
448-
elif gate_family == VIRTUAL_ZPOW_GATE_FAMILY:
477+
elif gate_family == _VIRTUAL_ZPOW_GATE_FAMILY:
449478
gate_spec.virtual_zpow.SetInParent()
450-
elif gate_family == PHYSICAL_ZPOW_GATE_FAMILY:
479+
elif gate_family == _PHYSICAL_ZPOW_GATE_FAMILY:
451480
gate_spec.physical_zpow.SetInParent()
452-
elif gate_family == COUPLER_PULSE_GATE_FAMILY:
481+
elif gate_family == _COUPLER_PULSE_GATE_FAMILY:
453482
gate_spec.coupler_pulse.SetInParent()
454-
elif gate_family == MEASUREMENT_GATE_FAMILY:
483+
elif gate_family == _MEASUREMENT_GATE_FAMILY:
455484
gate_spec.meas.SetInParent()
456-
elif gate_family == WAIT_GATE_FAMILY:
485+
elif gate_family == _WAIT_GATE_FAMILY:
457486
gate_spec.wait.SetInParent()
458487
else:
459488
raise ValueError(f'Unrecognized gate {gate_family}.')
460489

461490

462-
def create_device_specification_proto(
491+
def _create_device_specification_proto(
463492
*,
464493
qubits: Collection[cirq.GridQubit],
465494
pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]],
@@ -469,6 +498,13 @@ def create_device_specification_proto(
469498
) -> v2.device_pb2.DeviceSpecification:
470499
"""Serializes the given device information into a DeviceSpecification proto.
471500
501+
EXPERIMENTAL: DeviceSpecification serialization API may change.
502+
503+
This function does not serialize a `GridDevice`. Instead, it only takes a subset of device
504+
information sufficient to populate the `DeviceSpecification` proto. This reduces the complexity
505+
of constructing `DeviceSpecification` and `GridDevice` on server side by requiring only the bare
506+
essential device information.
507+
472508
Args:
473509
qubits: Collection of qubits available on the device.
474510
pairs: Collection of bidirectional qubit couplings available on the device.

cirq-google/cirq_google/devices/grid_device_test.py

+59-5
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,43 @@ def _create_device_spec_with_horizontal_couplings():
129129
}
130130

131131
expected_target_gatesets = (
132-
cirq.CZTargetGateset(),
132+
cirq_google.GoogleCZTargetGateset(
133+
additional_gates=[
134+
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
135+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
136+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
137+
cirq.ops.common_gates.XPowGate,
138+
cirq.ops.common_gates.YPowGate,
139+
cirq.ops.phased_x_gate.PhasedXPowGate,
140+
cirq.GateFamily(
141+
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
142+
),
143+
cirq.GateFamily(
144+
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
145+
),
146+
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
147+
cirq.ops.wait_gate.WaitGate,
148+
]
149+
),
133150
cirq_google.SycamoreTargetGateset(),
134-
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True),
151+
cirq.SqrtIswapTargetGateset(
152+
additional_gates=[
153+
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
154+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
155+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
156+
cirq.ops.common_gates.XPowGate,
157+
cirq.ops.common_gates.YPowGate,
158+
cirq.ops.phased_x_gate.PhasedXPowGate,
159+
cirq.GateFamily(
160+
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
161+
),
162+
cirq.GateFamily(
163+
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
164+
),
165+
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
166+
cirq.ops.wait_gate.WaitGate,
167+
]
168+
),
135169
)
136170

137171
return (
@@ -431,7 +465,7 @@ def test_to_proto():
431465
cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9,
432466
}
433467

434-
spec = grid_device.create_device_specification_proto(
468+
spec = grid_device._create_device_specification_proto(
435469
qubits=device_info.grid_qubits,
436470
pairs=device_info.qubit_pairs,
437471
gateset=cirq.Gateset(*gate_durations.keys()),
@@ -471,13 +505,13 @@ def test_to_proto():
471505
)
472506
def test_to_proto_invalid_input(error_match, qubits, qubit_pairs, gateset, gate_durations):
473507
with pytest.raises(ValueError, match=error_match):
474-
grid_device.create_device_specification_proto(
508+
grid_device._create_device_specification_proto(
475509
qubits=qubits, pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations
476510
)
477511

478512

479513
def test_to_proto_empty():
480-
spec = grid_device.create_device_specification_proto(
514+
spec = grid_device._create_device_specification_proto(
481515
# Qubits are always expected to be set
482516
qubits=[cirq.GridQubit(0, i) for i in range(5)],
483517
pairs=[],
@@ -490,3 +524,23 @@ def test_to_proto_empty():
490524
assert len(device.metadata.qubit_pairs) == 0
491525
assert device.metadata.gateset == cirq.Gateset()
492526
assert device.metadata.gate_durations is None
527+
528+
529+
def test_to_proto_fsim_gate_family():
530+
"""Verifies that FSimGateFamilies are serialized correctly."""
531+
532+
gateset = cirq.Gateset(
533+
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
534+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
535+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
536+
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
537+
)
538+
539+
spec = grid_device._create_device_specification_proto(
540+
qubits=[cirq.GridQubit(0, 0)], pairs=(), gateset=gateset
541+
)
542+
543+
assert any(gate_spec.HasField('syc') for gate_spec in spec.valid_gates)
544+
assert any(gate_spec.HasField('sqrt_iswap') for gate_spec in spec.valid_gates)
545+
assert any(gate_spec.HasField('sqrt_iswap_inv') for gate_spec in spec.valid_gates)
546+
assert any(gate_spec.HasField('cz') for gate_spec in spec.valid_gates)

cirq-google/cirq_google/devices/known_devices.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _create_grid_device_from_diagram(
8080
if neighbor > qubit and neighbor in qubit_set:
8181
pairs.append((qubit, cast(cirq.GridQubit, neighbor)))
8282

83-
device_specification = grid_device.create_device_specification_proto(
83+
device_specification = grid_device._create_device_specification_proto(
8484
qubits=qubits, pairs=pairs, gateset=gateset, gate_durations=gate_durations, out=out
8585
)
8686
return grid_device.GridDevice.from_proto(device_specification)

0 commit comments

Comments
 (0)