Skip to content

Commit 8e4e7d1

Browse files
authored
Add bitsize field to Cirq-FT Registers (#6286)
1 parent 188bb94 commit 8e4e7d1

File tree

4 files changed

+43
-26
lines changed

4 files changed

+43
-26
lines changed

cirq-ft/cirq_ft/algos/swap_network.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
152152

153153
@cached_property
154154
def target_registers(self) -> Tuple[infra.Register, ...]:
155-
return (infra.Register('target', (self.n_target_registers, self.target_bitsize)),)
155+
return (
156+
infra.Register('target', bitsize=self.target_bitsize, shape=self.n_target_registers),
157+
)
156158

157159
@cached_property
158160
def registers(self) -> infra.Registers:

cirq-ft/cirq_ft/infra/gate_with_registers.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"source": [
4040
"## `Registers`\n",
4141
"\n",
42-
"`Register` objects have a name and a shape. `Registers` is an ordered collection of `Register` with some helpful methods."
42+
"`Register` objects have a name, a bitsize and a shape. `Registers` is an ordered collection of `Register` with some helpful methods."
4343
]
4444
},
4545
{
@@ -51,8 +51,8 @@
5151
"source": [
5252
"from cirq_ft import Register, Registers, infra\n",
5353
"\n",
54-
"control_reg = Register(name='control', shape=(2,))\n",
55-
"target_reg = Register(name='target', shape=(3,))\n",
54+
"control_reg = Register(name='control', bitsize=2)\n",
55+
"target_reg = Register(name='target', bitsize=3)\n",
5656
"control_reg, target_reg"
5757
]
5858
},

cirq-ft/cirq_ft/infra/gate_with_registers.py

+28-16
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ class Register:
3232
"""
3333

3434
name: str
35+
bitsize: int
3536
shape: Tuple[int, ...] = attr.field(
36-
converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
37+
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
3738
)
3839

3940
def all_idxs(self) -> Iterable[Tuple[int, ...]]:
@@ -45,15 +46,14 @@ def total_bits(self) -> int:
4546
4647
This is the product of each of the dimensions in `shape`.
4748
"""
48-
return int(np.product(self.shape))
49+
return self.bitsize * int(np.product(self.shape))
4950

5051
def __repr__(self):
51-
return f'cirq_ft.Register(name="{self.name}", shape={self.shape})'
52+
return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})'
5253

5354

5455
def total_bits(registers: Iterable[Register]) -> int:
5556
"""Sum of `reg.total_bits()` for each register `reg` in input `registers`."""
56-
5757
return sum(reg.total_bits() for reg in registers)
5858

5959

@@ -65,7 +65,9 @@ def split_qubits(
6565
qubit_regs = {}
6666
base = 0
6767
for reg in registers:
68-
qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(reg.shape)
68+
qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(
69+
reg.shape + (reg.bitsize,)
70+
)
6971
base += reg.total_bits()
7072
return qubit_regs
7173

@@ -82,9 +84,10 @@ def merge_qubits(
8284
raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs")
8385
qubits = qubit_regs[reg.name]
8486
qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits)
85-
if qubits.shape != reg.shape:
87+
full_shape = reg.shape + (reg.bitsize,)
88+
if qubits.shape != full_shape:
8689
raise ValueError(
87-
f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}'
90+
f'{reg.name} register must of shape {full_shape} but is of shape {qubits.shape}'
8891
)
8992
ret += qubits.flatten().tolist()
9093
return ret
@@ -94,13 +97,16 @@ def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qi
9497
"""Returns a dictionary of appropriately shaped named qubit registers for input `registers`."""
9598

9699
def _qubit_array(reg: Register):
97-
qubits = np.empty(reg.shape, dtype=object)
100+
qubits = np.empty(reg.shape + (reg.bitsize,), dtype=object)
98101
for ii in reg.all_idxs():
99-
qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]')
102+
for j in range(reg.bitsize):
103+
prefix = "" if not ii else f'[{", ".join(str(i) for i in ii)}]'
104+
suffix = "" if reg.bitsize == 1 else f"[{j}]"
105+
qubits[ii + (j,)] = cirq.NamedQubit(reg.name + prefix + suffix)
100106
return qubits
101107

102108
def _qubits_for_reg(reg: Register):
103-
if len(reg.shape) > 1:
109+
if len(reg.shape) > 0:
104110
return _qubit_array(reg)
105111

106112
return np.array(
@@ -130,8 +136,8 @@ def __repr__(self):
130136
return f'cirq_ft.Registers({self._registers})'
131137

132138
@classmethod
133-
def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers':
134-
return cls(Register(name=k, shape=v) for k, v in registers.items())
139+
def build(cls, **registers: int) -> 'Registers':
140+
return cls(Register(name=k, bitsize=v) for k, v in registers.items())
135141

136142
@overload
137143
def __getitem__(self, key: int) -> Register:
@@ -216,23 +222,29 @@ class SelectionRegister(Register):
216222
>>> assert len(flat_indices) == N * M * L
217223
"""
218224

225+
name: str
226+
bitsize: int
219227
iteration_length: int = attr.field()
228+
shape: Tuple[int, ...] = attr.field(
229+
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
230+
)
220231

221232
@iteration_length.default
222233
def _default_iteration_length(self):
223-
return 2 ** self.shape[0]
234+
return 2**self.bitsize
224235

225236
@iteration_length.validator
226237
def validate_iteration_length(self, attribute, value):
227-
if len(self.shape) != 1:
238+
if len(self.shape) != 0:
228239
raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}')
229-
if not (0 <= value <= 2 ** self.shape[0]):
230-
raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]')
240+
if not (0 <= value <= 2**self.bitsize):
241+
raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]')
231242

232243
def __repr__(self) -> str:
233244
return (
234245
f'cirq_ft.SelectionRegister('
235246
f'name="{self.name}", '
247+
f'bitsize={self.bitsize}, '
236248
f'shape={self.shape}, '
237249
f'iteration_length={self.iteration_length})'
238250
)

cirq-ft/cirq_ft/infra/gate_with_registers_test.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121

2222

2323
def test_register():
24-
r = cirq_ft.Register("my_reg", 5)
25-
assert r.shape == (5,)
24+
r = cirq_ft.Register("my_reg", 5, (1, 2))
25+
assert r.bitsize == 5
26+
assert r.shape == (1, 2)
2627

2728

2829
def test_registers():
@@ -103,12 +104,12 @@ def test_selection_registers_consistent():
103104
_ = cirq_ft.SelectionRegister('a', 3, 10)
104105

105106
with pytest.raises(ValueError, match="should be flat"):
106-
_ = cirq_ft.SelectionRegister('a', (3, 5), 5)
107+
_ = cirq_ft.SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5)
107108

108109
selection_reg = cirq_ft.Registers(
109110
[
110-
cirq_ft.SelectionRegister('n', shape=3, iteration_length=5),
111-
cirq_ft.SelectionRegister('m', shape=4, iteration_length=12),
111+
cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5),
112+
cirq_ft.SelectionRegister('m', bitsize=4, iteration_length=12),
112113
]
113114
)
114115
assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5)
@@ -122,7 +123,9 @@ def test_registers_getitem_raises():
122123
with pytest.raises(IndexError, match="must be of the type"):
123124
_ = g[2.5]
124125

125-
selection_reg = cirq_ft.Registers([cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)])
126+
selection_reg = cirq_ft.Registers(
127+
[cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5)]
128+
)
126129
with pytest.raises(IndexError, match='must be of the type'):
127130
_ = selection_reg[2.5]
128131

0 commit comments

Comments
 (0)