Skip to content

Repetition IDs for CircuitOperation #3741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 8, 2021
155 changes: 139 additions & 16 deletions cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
component operations in order, including any nested CircuitOperations.
"""

from typing import TYPE_CHECKING, AbstractSet, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, AbstractSet, Callable, Dict, List, Optional, Tuple, Union

import dataclasses
import numpy as np
Expand All @@ -33,6 +33,22 @@
INT_TYPE = Union[int, np.integer]


def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
if abs(repetitions) > 1:
return [str(i) for i in range(abs(repetitions))]
return None


def cartesian_product_of_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
if list1 is None and list2 is None:
return None # coverage: ignore
if list1 is None:
return list2 # coverage: ignore
if list2 is None:
return list1
return [f'{first}-{second}' for first in list1 for second in list2]


@dataclasses.dataclass(frozen=True)
class CircuitOperation(ops.Operation):
"""An operation that encapsulates a circuit.
Expand All @@ -47,6 +63,10 @@ class CircuitOperation(ops.Operation):
qubit_map: Remappings for qubits in the circuit.
measurement_key_map: Remappings for measurement keys in the circuit.
param_resolver: Resolved values for parameters in the circuit.
repetition_ids: List of identifiers for each repetition of the
CircuitOperation. If populated, the length should be equal to the
repetitions. If not populated and abs(`repetitions`) > 1, it is
initialized to strings for numbers in `range(repetitions)`.
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
Expand All @@ -56,10 +76,28 @@ class CircuitOperation(ops.Operation):
qubit_map: Dict['cirq.Qid', 'cirq.Qid'] = dataclasses.field(default_factory=dict)
measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict)
param_resolver: study.ParamResolver = study.ParamResolver()
repetition_ids: Optional[List[str]] = dataclasses.field(default=None)

def __post_init__(self):
if not isinstance(self.circuit, circuits.FrozenCircuit):
raise TypeError(f'Expected circuit of type FrozenCircuit, got: {type(self.circuit)!r}')

# Ensure that the circuit is invertible if the repetitions are negative.
if self.repetitions < 0:
try:
protocols.inverse(self.circuit.unfreeze())
except TypeError:
raise ValueError(f'repetitions are negative but the circuit is not invertible')

# Initialize repetition_ids to default, if unspecified. Else, validate their length.
loop_size = abs(self.repetitions)
if not self.repetition_ids:
object.__setattr__(self, 'repetition_ids', self._default_repetition_ids())
elif len(self.repetition_ids) != loop_size:
raise ValueError(
f'Expected repetition_ids to be a list of length {loop_size}, '
f'got: {self.repetition_ids}'
)
# Ensure that param_resolver is converted to an actual ParamResolver.
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))

Expand All @@ -83,6 +121,7 @@ def __eq__(self, other) -> bool:
and self.measurement_key_map == other.measurement_key_map
and self.param_resolver == other.param_resolver
and self.repetitions == other.repetitions
and self.repetition_ids == other.repetition_ids
)

# Methods for getting post-mapping properties of the contained circuit.
Expand All @@ -93,6 +132,9 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
ordered_qubits = ops.QubitOrder.DEFAULT.order_for(self.circuit.all_qubits())
return tuple(self.qubit_map.get(q, q) for q in ordered_qubits)

def _default_repetition_ids(self) -> Optional[List[str]]:
return default_repetition_ids(self.repetitions)

def _qid_shape_(self) -> Tuple[int, ...]:
return tuple(q.dimension for q in self.qubits)

Expand All @@ -117,8 +159,39 @@ def _decompose_(self) -> 'cirq.OP_TREE':
result = result ** -1
result = protocols.with_measurement_key_mapping(result, self.measurement_key_map)
result = protocols.resolve_parameters(result, self.param_resolver, recursive=False)

return list(result.all_operations()) * abs(self.repetitions)
# repetition_ids don't need to be taken into account if the circuit has no measurements
# or if repetition_ids are unset.
if self.repetition_ids is None or not protocols.is_measurement(result):
return list(result.all_operations()) * abs(self.repetitions)
# If it's a measurement circuit with repetitions/repetition_ids, prefix the repetition_ids
# to measurements. Details at https://tinyurl.com/measurement-repeated-circuitop.
ops = [] # type: List[cirq.Operation]
for parent_id in self.repetition_ids:
for op in result.all_operations():
if isinstance(op, CircuitOperation):
# For a CircuitOperation, prefix the current repetition_id to the children
# repetition_ids.
ops.append(
op.with_repetition_ids(
# If `op.repetition_ids` is None, this will return `[parent_id]`.
cartesian_product_of_string_lists([parent_id], op.repetition_ids)
)
)
elif protocols.is_measurement(op):
# For a non-CircuitOperation measurement, prefix the current repetition_id
# to the children measurement keys. Implemented by creating a mapping and
# using the with_measurement_key_mapping protocol.
ops.append(
protocols.with_measurement_key_mapping(
op,
key_map={
key: f'{parent_id}-{key}' for key in protocols.measurement_keys(op)
},
)
)
else:
ops.append(op)
return ops

# Methods for string representation of the operation.

Expand All @@ -132,6 +205,9 @@ def __repr__(self):
args += f'measurement_key_map={proper_repr(self.measurement_key_map)},\n'
if self.param_resolver:
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
indented_args = args.replace('\n', '\n ')
return f'cirq.CircuitOperation({indented_args[:-4]})'

Expand All @@ -155,7 +231,11 @@ def dict_str(d: Dict) -> str:
args.append(f'key_map={dict_str(self.measurement_key_map)}')
if self.param_resolver:
args.append(f'params={self.param_resolver.param_dict}')
if self.repetitions != 1:
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args.append(f'repetition_ids={self.repetition_ids}')
elif self.repetitions != 1:
# Only add loops if we haven't added repetition_ids.
args.append(f'loops={self.repetitions}')
if not args:
return f'{header}\n{circuit_msg}'
Expand All @@ -173,6 +253,7 @@ def __hash__(self):
frozenset(self.qubit_map.items()),
frozenset(self.measurement_key_map.items()),
self.param_resolver,
tuple([] if self.repetition_ids is None else self.repetition_ids),
)
),
)
Expand All @@ -188,53 +269,95 @@ def _json_dict_(self):
'qubit_map': sorted(self.qubit_map.items()),
'measurement_key_map': self.measurement_key_map,
'param_resolver': self.param_resolver,
'repetition_ids': self.repetition_ids,
}

@classmethod
def _from_json_dict_(
cls, circuit, repetitions, qubit_map, measurement_key_map, param_resolver, **kwargs
cls,
circuit,
repetitions,
qubit_map,
measurement_key_map,
param_resolver,
repetition_ids,
**kwargs,
):
return (
cls(circuit)
.with_qubit_mapping(dict(qubit_map))
.with_measurement_key_mapping(measurement_key_map)
.with_params(param_resolver)
.repeat(repetitions)
.repeat(repetitions, repetition_ids)
)

# Methods for constructing a similar object with one field modified.

def repeat(
self,
repetitions: INT_TYPE,
repetitions: Optional[INT_TYPE] = None,
repetition_ids: Optional[List[str]] = None,
) -> 'CircuitOperation':
"""Returns a copy of this operation repeated 'repetitions' times.
Each repetition instance will be identified by a single repetition_id.

Args:
repetitions: Number of times this operation should repeat. This
is multiplied with any pre-existing repetitions.
is multiplied with any pre-existing repetitions. If unset, it
defaults to the length of `repetition_ids`.
repetition_ids: List of IDs, one for each repetition. If unset,
defaults to `default_repetition_ids(repetitions)`.

Returns:
A copy of this operation repeated 'repetitions' times.
A copy of this operation repeated `repetitions` times with the
appropriate `repetition_ids`. The output `repetition_ids` are the
cartesian product of input `repetition_ids` with the base
operation's `repetition_ids`. If the base operation has unset
`repetition_ids` (indicates {-1, 0, 1} `repetitions` with no custom
IDs), the input `repetition_ids` are directly used.

Raises:
TypeError: `repetitions` is not an integer value.
NotImplementedError: The operation contains measurements and
cannot have repetitions.
ValueError: Unexpected length of `repetition_ids`.
ValueError: Both `repetitions` and `repetition_ids` are None.
"""
if repetitions is None:
if repetition_ids is None:
raise ValueError('At least one of repetitions and repetition_ids must be set')
repetitions = len(repetition_ids)

if not isinstance(repetitions, (int, np.integer)):
raise TypeError('Only integer repetitions are allowed.')
if repetitions == 1:

repetitions = int(repetitions)

if repetitions == 1 and repetition_ids is None:
# As CircuitOperation is immutable, this can safely return the original.
return self
repetitions = int(repetitions)
if protocols.is_measurement(self.circuit):
raise NotImplementedError('Loops over measurements are not supported.')
return self.replace(repetitions=self.repetitions * repetitions)

expected_repetition_id_length = abs(repetitions)
# The eventual number of repetitions of the returned CircuitOperation.
final_repetitions = self.repetitions * repetitions

if repetition_ids is None:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
f'Expected repetition_ids={repetition_ids} length to be '
f'{expected_repetition_id_length}'
)

# If `self.repetition_ids` is None, this will just return `repetition_ids`.
repetition_ids = cartesian_product_of_string_lists(repetition_ids, self.repetition_ids)

return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)

def __pow__(self, power: int) -> 'CircuitOperation':
return self.repeat(power)

def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
return self.replace(repetition_ids=repetition_ids)

def with_qubit_mapping(
self,
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
Expand Down
Loading