13
13
# limitations under the License.
14
14
15
15
import itertools
16
- from typing import Callable , Sequence
16
+ from typing import Callable , Sequence , Tuple
17
17
18
18
import attr
19
19
import cirq
20
+ import numpy as np
20
21
from cirq ._compat import cached_property
21
22
from cirq_ft import infra
22
23
from cirq_ft .algos import unary_iteration_gate
@@ -36,8 +37,8 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
36
37
`selection`-th qubit of `target` all controlled by the `control` register.
37
38
38
39
Args:
39
- selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains
40
- information about the iteration length of each selection register.
40
+ selection_regs: Indexing `select` registers of type Tuple[ `SelectionRegisters`, ...].
41
+ It also contains information about the iteration length of each selection register.
41
42
nth_gate: A function mapping the composite selection index to a single-qubit gate.
42
43
control_regs: Control registers for constructing a controlled version of the gate.
43
44
@@ -46,43 +47,45 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
46
47
(https://arxiv.org/abs/1805.03662).
47
48
Babbush et. al. (2018). Section III.A. and Figure 7.
48
49
"""
49
- selection_regs : infra .SelectionRegisters
50
+ selection_regs : Tuple [infra .SelectionRegister , ...] = attr .field (
51
+ converter = lambda v : (v ,) if isinstance (v , infra .SelectionRegister ) else tuple (v )
52
+ )
50
53
nth_gate : Callable [..., cirq .Gate ]
51
- control_regs : infra .Registers = infra .Registers .build (control = 1 )
54
+ control_regs : Tuple [infra .Register , ...] = attr .field (
55
+ converter = lambda v : (v ,) if isinstance (v , infra .Register ) else tuple (v ),
56
+ default = (infra .Register ('control' , 1 ),),
57
+ )
52
58
53
59
@classmethod
54
60
def make_on (
55
61
cls , * , nth_gate : Callable [..., cirq .Gate ], ** quregs : Sequence [cirq .Qid ]
56
62
) -> cirq .Operation :
57
63
"""Helper constructor to automatically deduce bitsize attributes."""
58
- return cls (
59
- infra .SelectionRegisters (
60
- [
61
- infra .SelectionRegister (
62
- 'selection' , len (quregs ['selection' ]), len (quregs ['target' ])
63
- )
64
- ]
65
- ),
64
+ return ApplyGateToLthQubit (
65
+ infra .SelectionRegister ('selection' , len (quregs ['selection' ]), len (quregs ['target' ])),
66
66
nth_gate = nth_gate ,
67
- control_regs = infra .Registers . build ( control = len (quregs ['control' ])),
67
+ control_regs = infra .Register ( ' control' , len (quregs ['control' ])),
68
68
).on_registers (** quregs )
69
69
70
70
@cached_property
71
- def control_registers (self ) -> infra .Registers :
71
+ def control_registers (self ) -> Tuple [ infra .Register , ...] :
72
72
return self .control_regs
73
73
74
74
@cached_property
75
- def selection_registers (self ) -> infra .SelectionRegisters :
75
+ def selection_registers (self ) -> Tuple [ infra .SelectionRegister , ...] :
76
76
return self .selection_regs
77
77
78
78
@cached_property
79
- def target_registers (self ) -> infra .Registers :
80
- return infra .Registers .build (target = self .selection_registers .total_iteration_size )
79
+ def target_registers (self ) -> Tuple [infra .Register , ...]:
80
+ total_iteration_size = np .product (
81
+ tuple (reg .iteration_length for reg in self .selection_registers )
82
+ )
83
+ return (infra .Register ('target' , int (total_iteration_size )),)
81
84
82
85
def _circuit_diagram_info_ (self , args : cirq .CircuitDiagramInfoArgs ) -> cirq .CircuitDiagramInfo :
83
- wire_symbols = ["@" ] * self . control_registers . total_bits ()
84
- wire_symbols += ["In" ] * self . selection_registers . total_bits ()
85
- for it in itertools .product (* [range (x ) for x in self .selection_regs . iteration_lengths ]):
86
+ wire_symbols = ["@" ] * infra . total_bits (self . control_registers )
87
+ wire_symbols += ["In" ] * infra . total_bits (self . selection_registers )
88
+ for it in itertools .product (* [range (reg . iteration_length ) for reg in self .selection_regs ]):
86
89
wire_symbols += [str (self .nth_gate (* it ))]
87
90
return cirq .CircuitDiagramInfo (wire_symbols = wire_symbols )
88
91
@@ -93,6 +96,7 @@ def nth_operation( # type: ignore[override]
93
96
target : Sequence [cirq .Qid ],
94
97
** selection_indices : int ,
95
98
) -> cirq .OP_TREE :
99
+ selection_shape = tuple (reg .iteration_length for reg in self .selection_regs )
96
100
selection_idx = tuple (selection_indices [reg .name ] for reg in self .selection_regs )
97
- target_idx = self . selection_registers . to_flat_idx ( * selection_idx )
101
+ target_idx = int ( np . ravel_multi_index ( selection_idx , selection_shape ) )
98
102
return self .nth_gate (* selection_idx ).on (target [target_idx ]).controlled_by (control )
0 commit comments