-
Notifications
You must be signed in to change notification settings - Fork 1.1k
(De-)serialization of CircuitOperations #3923
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
185e95e
c80d982
1ee4b30
fdaa216
4aa8ccf
14ca867
de2b042
fded376
7f3a269
1877a71
8b2ed14
b98219d
c3f5215
69b7168
138441b
f6020d4
4759ff2
c717ef9
0a90658
a37dccb
b7ab8d6
c40d061
233fe88
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,11 @@ | |
) | ||
from dataclasses import dataclass | ||
|
||
import abc | ||
import sympy | ||
|
||
from cirq import circuits | ||
from cirq._compat import deprecated | ||
from cirq.google import arg_func_langs | ||
from cirq.google.api import v2 | ||
from cirq.google.ops.calibration_tag import CalibrationTag | ||
|
@@ -31,6 +36,34 @@ | |
import cirq | ||
|
||
|
||
class OpDeserializer(abc.ABC): | ||
"""Generic supertype for op deserializers.""" | ||
|
||
@property | ||
@abc.abstractmethod | ||
def serialized_id(self) -> str: | ||
"""Returns the string identifier for the resulting serialized object. | ||
|
||
This value should reflect the internal_type of the serializer. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure I understand what this sentence means. What does "should reflect the internal_type" imply? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I must have misunderstood the IDs when typing this up (plus a copy-paste error in the deserializer comment). Updated to (hopefully) more accurately represent the behavior. |
||
""" | ||
|
||
@property # type: ignore | ||
@deprecated(deadline='v0.12', fix='Use serialized_id instead.') | ||
def serialized_gate_id(self) -> str: | ||
return self.serialized_id | ||
|
||
@abc.abstractmethod | ||
def from_proto( | ||
self, | ||
proto, | ||
*, | ||
arg_function_language: str = '', | ||
constants: List[v2.program_pb2.Constant] = None, | ||
raw_constants: List[Any] = None, | ||
95-martin-orion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> 'cirq.Operation': | ||
pass | ||
|
||
|
||
@dataclass(frozen=True) | ||
class DeserializingArg: | ||
"""Specification of the arguments to deserialize an argument to a gate. | ||
|
@@ -57,7 +90,7 @@ class DeserializingArg: | |
default: Any = None | ||
|
||
|
||
class GateOpDeserializer: | ||
class GateOpDeserializer(OpDeserializer): | ||
"""Describes how to deserialize a proto to a given Gate type. | ||
|
||
Attributes: | ||
|
@@ -94,28 +127,33 @@ def __init__( | |
deserialize_tokens: Whether to convert tokens to | ||
CalibrationTags. Defaults to True. | ||
""" | ||
self.serialized_gate_id = serialized_gate_id | ||
self.gate_constructor = gate_constructor | ||
self.args = args | ||
self.num_qubits_param = num_qubits_param | ||
self.op_wrapper = op_wrapper | ||
self.deserialize_tokens = deserialize_tokens | ||
self._serialized_gate_id = serialized_gate_id | ||
self._gate_constructor = gate_constructor | ||
self._args = args | ||
self._num_qubits_param = num_qubits_param | ||
self._op_wrapper = op_wrapper | ||
self._deserialize_tokens = deserialize_tokens | ||
|
||
@property | ||
def serialized_id(self): | ||
return self._serialized_gate_id | ||
|
||
def from_proto( | ||
self, | ||
proto: v2.program_pb2.Operation, | ||
*, | ||
arg_function_language: str = '', | ||
constants: List[v2.program_pb2.Constant] = None, | ||
raw_constants: List[Any] = None, # unused | ||
) -> 'cirq.Operation': | ||
"""Turns a cirq.google.api.v2.Operation proto into a GateOperation.""" | ||
qubits = [v2.qubit_from_proto_id(q.id) for q in proto.qubits] | ||
args = self._args_from_proto(proto, arg_function_language=arg_function_language) | ||
if self.num_qubits_param is not None: | ||
args[self.num_qubits_param] = len(qubits) | ||
gate = self.gate_constructor(**args) | ||
op = self.op_wrapper(gate.on(*qubits), proto) | ||
if self.deserialize_tokens: | ||
if self._num_qubits_param is not None: | ||
args[self._num_qubits_param] = len(qubits) | ||
gate = self._gate_constructor(**args) | ||
op = self._op_wrapper(gate.on(*qubits), proto) | ||
if self._deserialize_tokens: | ||
which = proto.WhichOneof('token') | ||
if which == 'token_constant_index': | ||
if not constants: | ||
|
@@ -135,7 +173,7 @@ def _args_from_proto( | |
self, proto: v2.program_pb2.Operation, *, arg_function_language: str | ||
) -> Dict[str, arg_func_langs.ARG_LIKE]: | ||
return_args = {} | ||
for arg in self.args: | ||
for arg in self._args: | ||
if arg.serialized_name not in proto.args: | ||
if arg.default: | ||
return_args[arg.constructor_arg_name] = arg.default | ||
|
@@ -158,3 +196,86 @@ def _args_from_proto( | |
if value is not None: | ||
return_args[arg.constructor_arg_name] = value | ||
return return_args | ||
|
||
|
||
class CircuitOpDeserializer(OpDeserializer): | ||
"""Describes how to serialize CircuitOperations.""" | ||
|
||
@property | ||
def serialized_id(self): | ||
return 'circuit' | ||
|
||
def from_proto( | ||
self, | ||
proto: v2.program_pb2.CircuitOperation, | ||
*, | ||
arg_function_language: str = '', | ||
constants: List[v2.program_pb2.Constant] = None, | ||
raw_constants: List[Any] = None, | ||
) -> 'cirq.CircuitOperation': | ||
"""Turns a cirq.google.api.v2.CircuitOperation proto into a CircuitOperation.""" | ||
if constants is None or raw_constants is None: | ||
95-martin-orion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError( | ||
'CircuitOp deserialization requires a constants list and a corresponding list of ' | ||
'post-deserialization values (raw_constants).' | ||
) | ||
circuit = raw_constants[proto.circuit_constant_index] | ||
if not isinstance(circuit, circuits.FrozenCircuit): | ||
raise ValueError( | ||
f'Constant at index {proto.circuit_constant_index} was expected to be a circuit, ' | ||
f'but it has type {type(circuit)} in the raw_constants list.' | ||
) | ||
|
||
which_rep_spec = proto.repetition_specification.WhichOneof('repetition_value') | ||
if which_rep_spec == "repetition_count": | ||
95-martin-orion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rep_ids = None | ||
repetitions = proto.repetition_specification.repetition_count | ||
elif which_rep_spec == "repetition_ids": | ||
rep_ids = proto.repetition_specification.repetition_ids.ids | ||
repetitions = len(rep_ids) | ||
else: | ||
rep_ids = None | ||
repetitions = 1 | ||
|
||
qubit_map = { | ||
v2.qubit_from_proto_id(entry.key.id): v2.qubit_from_proto_id(entry.value.id) | ||
for entry in proto.qubit_map.entries | ||
} | ||
measurement_key_map = { | ||
entry.key.string_key: entry.value.string_key | ||
for entry in proto.measurement_key_map.entries | ||
} | ||
arg_map = { | ||
arg_func_langs.arg_from_proto( | ||
entry.key, arg_function_language=arg_function_language | ||
): arg_func_langs.arg_from_proto( | ||
entry.value, arg_function_language=arg_function_language | ||
) | ||
for entry in proto.arg_map.entries | ||
} | ||
|
||
for arg in arg_map.keys(): | ||
if not isinstance(arg, (str, sympy.Symbol)): | ||
print('whoopee') | ||
95-martin-orion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError( | ||
'Invalid key parameter type in deserialized CircuitOperation. ' | ||
f'Expected str or sympy.Symbol, found {type(arg)}.' | ||
f'\nFull arg: {arg}' | ||
) | ||
|
||
for arg in arg_map.values(): | ||
if not isinstance(arg, (str, sympy.Symbol, float, int)): | ||
raise ValueError( | ||
'Invalid value parameter type in deserialized CircuitOperation. ' | ||
f'Expected str, sympy.Symbol, or number; found {type(arg)}.' | ||
f'\nFull arg: {arg}' | ||
) | ||
|
||
return circuits.CircuitOperation( | ||
circuit, | ||
repetitions, | ||
qubit_map, | ||
measurement_key_map, | ||
arg_map, # type: ignore | ||
rep_ids, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing that we should think about is whether a device specification should specify that the device supports circuit operations or not. (Or should we assume they all support circuit operations?)
This may have to be in a later PR, since it involves further proto changes, but we should give it some thought and maybe add an issue about it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apart from the transitional phase where we're still bringing support online, I think assuming all devices support circuit operations is valid, since in the worst case you can simply decompose the circuit operation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO(95-martin-orion): test creating a device spec with CircuitOperation to identify necessary changes / ensure that this is well-behaved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dstrain115: added
test_sycamore_circuitop_device
to provide this coverage, and it failed. Updatedserializable_device
to resolve the issue.