@@ -32,8 +32,9 @@ class Register:
32
32
"""
33
33
34
34
name : str
35
+ bitsize : int
35
36
shape : Tuple [int , ...] = attr .field (
36
- converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v )
37
+ converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v ), default = ()
37
38
)
38
39
39
40
def all_idxs (self ) -> Iterable [Tuple [int , ...]]:
@@ -45,15 +46,14 @@ def total_bits(self) -> int:
45
46
46
47
This is the product of each of the dimensions in `shape`.
47
48
"""
48
- return int (np .product (self .shape ))
49
+ return self . bitsize * int (np .product (self .shape ))
49
50
50
51
def __repr__ (self ):
51
- return f'cirq_ft.Register(name="{ self .name } ", shape={ self .shape } )'
52
+ return f'cirq_ft.Register(name="{ self .name } ", bitsize= { self . bitsize } , shape={ self .shape } )'
52
53
53
54
54
55
def total_bits (registers : Iterable [Register ]) -> int :
55
56
"""Sum of `reg.total_bits()` for each register `reg` in input `registers`."""
56
-
57
57
return sum (reg .total_bits () for reg in registers )
58
58
59
59
@@ -65,7 +65,9 @@ def split_qubits(
65
65
qubit_regs = {}
66
66
base = 0
67
67
for reg in registers :
68
- qubit_regs [reg .name ] = np .array (qubits [base : base + reg .total_bits ()]).reshape (reg .shape )
68
+ qubit_regs [reg .name ] = np .array (qubits [base : base + reg .total_bits ()]).reshape (
69
+ reg .shape + (reg .bitsize ,)
70
+ )
69
71
base += reg .total_bits ()
70
72
return qubit_regs
71
73
@@ -82,9 +84,10 @@ def merge_qubits(
82
84
raise ValueError (f"All qubit registers must be present. { reg .name } not in qubit_regs" )
83
85
qubits = qubit_regs [reg .name ]
84
86
qubits = np .array ([qubits ] if isinstance (qubits , cirq .Qid ) else qubits )
85
- if qubits .shape != reg .shape :
87
+ full_shape = reg .shape + (reg .bitsize ,)
88
+ if qubits .shape != full_shape :
86
89
raise ValueError (
87
- f'{ reg .name } register must of shape { reg . shape } but is of shape { qubits .shape } '
90
+ f'{ reg .name } register must of shape { full_shape } but is of shape { qubits .shape } '
88
91
)
89
92
ret += qubits .flatten ().tolist ()
90
93
return ret
@@ -94,13 +97,16 @@ def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qi
94
97
"""Returns a dictionary of appropriately shaped named qubit registers for input `registers`."""
95
98
96
99
def _qubit_array (reg : Register ):
97
- qubits = np .empty (reg .shape , dtype = object )
100
+ qubits = np .empty (reg .shape + ( reg . bitsize ,) , dtype = object )
98
101
for ii in reg .all_idxs ():
99
- qubits [ii ] = cirq .NamedQubit (f'{ reg .name } [{ ", " .join (str (i ) for i in ii )} ]' )
102
+ for j in range (reg .bitsize ):
103
+ prefix = "" if not ii else f'[{ ", " .join (str (i ) for i in ii )} ]'
104
+ suffix = "" if reg .bitsize == 1 else f"[{ j } ]"
105
+ qubits [ii + (j ,)] = cirq .NamedQubit (reg .name + prefix + suffix )
100
106
return qubits
101
107
102
108
def _qubits_for_reg (reg : Register ):
103
- if len (reg .shape ) > 1 :
109
+ if len (reg .shape ) > 0 :
104
110
return _qubit_array (reg )
105
111
106
112
return np .array (
@@ -130,8 +136,8 @@ def __repr__(self):
130
136
return f'cirq_ft.Registers({ self ._registers } )'
131
137
132
138
@classmethod
133
- def build (cls , ** registers : Union [ int , Tuple [ int , ...]] ) -> 'Registers' :
134
- return cls (Register (name = k , shape = v ) for k , v in registers .items ())
139
+ def build (cls , ** registers : int ) -> 'Registers' :
140
+ return cls (Register (name = k , bitsize = v ) for k , v in registers .items ())
135
141
136
142
@overload
137
143
def __getitem__ (self , key : int ) -> Register :
@@ -216,23 +222,29 @@ class SelectionRegister(Register):
216
222
>>> assert len(flat_indices) == N * M * L
217
223
"""
218
224
225
+ name : str
226
+ bitsize : int
219
227
iteration_length : int = attr .field ()
228
+ shape : Tuple [int , ...] = attr .field (
229
+ converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v ), default = ()
230
+ )
220
231
221
232
@iteration_length .default
222
233
def _default_iteration_length (self ):
223
- return 2 ** self .shape [ 0 ]
234
+ return 2 ** self .bitsize
224
235
225
236
@iteration_length .validator
226
237
def validate_iteration_length (self , attribute , value ):
227
- if len (self .shape ) != 1 :
238
+ if len (self .shape ) != 0 :
228
239
raise ValueError (f'Selection register { self .name } should be flat. Found { self .shape = } ' )
229
- if not (0 <= value <= 2 ** self .shape [ 0 ] ):
230
- raise ValueError (f'iteration length must be in range [0, 2^{ self .shape [ 0 ] } ]' )
240
+ if not (0 <= value <= 2 ** self .bitsize ):
241
+ raise ValueError (f'iteration length must be in range [0, 2^{ self .bitsize } ]' )
231
242
232
243
def __repr__ (self ) -> str :
233
244
return (
234
245
f'cirq_ft.SelectionRegister('
235
246
f'name="{ self .name } ", '
247
+ f'bitsize={ self .bitsize } , '
236
248
f'shape={ self .shape } , '
237
249
f'iteration_length={ self .iteration_length } )'
238
250
)
0 commit comments