Skip to content

Commit acbc624

Browse files
authored
Do not allow creating registers with bitsize 0 (#6298)
* Do not allow creating registers with bitsize 0 * Fix mypy errors
1 parent 8e4e7d1 commit acbc624

8 files changed

+25
-11
lines changed

Diff for: cirq-ft/cirq_ft/algos/and_gate.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ def _decompose_via_tree(
141141
def decompose_from_registers(
142142
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
143143
) -> cirq.OP_TREE:
144-
control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target']
144+
control, ancilla, target = (
145+
quregs['control'],
146+
quregs.get('ancilla', np.array([])),
147+
quregs['target'],
148+
)
145149
if len(self.cv) == 2:
146150
yield self._decompose_single_and(
147151
self.cv[0], self.cv[1], control[0], control[1], *target

Diff for: cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
7777

7878
@cached_property
7979
def target_registers(self) -> Tuple[infra.Register, ...]:
80-
total_iteration_size = np.product(
80+
total_iteration_size = np.prod(
8181
tuple(reg.iteration_length for reg in self.selection_registers)
8282
)
8383
return (infra.Register('target', int(total_iteration_size)),)

Diff for: cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def decompose_from_registers(
6969
context: cirq.DecompositionContext,
7070
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
7171
) -> cirq.OP_TREE:
72-
controls, target = quregs['controls'], quregs['target']
72+
controls, target = quregs.get('controls', ()), quregs['target']
7373
# Find K and L as per https://arxiv.org/abs/1805.03662 Fig 12.
7474
n, k = self.n, 0
7575
while n > 1 and n % 2 == 0:

Diff for: cirq-ft/cirq_ft/algos/qrom.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
111111

112112
@cached_property
113113
def target_registers(self) -> Tuple[infra.Register, ...]:
114-
return tuple(infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes))
114+
return tuple(
115+
infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes) if l
116+
)
115117

116118
def __repr__(self) -> str:
117119
data_repr = f"({','.join(cirq._compat.proper_repr(d) for d in self.data)})"
@@ -129,7 +131,7 @@ def _load_nth_data(
129131
**target_regs: NDArray[cirq.Qid], # type: ignore[type-var]
130132
) -> cirq.OP_TREE:
131133
for i, d in enumerate(self.data):
132-
target = target_regs[f'target{i}']
134+
target = target_regs.get(f'target{i}', ())
133135
for q, bit in zip(target, f'{int(d[selection_idx]):0{len(target)}b}'):
134136
if int(bit):
135137
yield gate(q)

Diff for: cirq-ft/cirq_ft/algos/selected_majorana_fermion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
7777

7878
@cached_property
7979
def target_registers(self) -> Tuple[infra.Register, ...]:
80-
total_iteration_size = np.product(
80+
total_iteration_size = np.prod(
8181
tuple(reg.iteration_length for reg in self.selection_registers)
8282
)
8383
return (infra.Register('target', int(total_iteration_size)),)

Diff for: cirq-ft/cirq_ft/algos/state_preparation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def decompose_from_registers(
167167
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
168168
) -> cirq.OP_TREE:
169169
selection, less_than_equal = quregs['selection'], quregs['less_than_equal']
170-
sigma_mu, alt, keep = quregs['sigma_mu'], quregs['alt'], quregs['keep']
170+
sigma_mu, alt, keep = quregs.get('sigma_mu', ()), quregs['alt'], quregs.get('keep', ())
171171
N = self.selection_registers[0].iteration_length
172172
yield prepare_uniform_superposition.PrepareUniformSuperposition(N).on(*selection)
173173
yield cirq.H.on_each(*sigma_mu)

Diff for: cirq-ft/cirq_ft/infra/gate_with_registers.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ class Register:
3232
"""
3333

3434
name: str
35-
bitsize: int
35+
bitsize: int = attr.field()
3636
shape: Tuple[int, ...] = attr.field(
3737
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
3838
)
3939

40+
@bitsize.validator
41+
def bitsize_validator(self, attribute, value):
42+
if value <= 0:
43+
raise ValueError(f"Bitsize for {self=} must be a positive integer. Found {value}.")
44+
4045
def all_idxs(self) -> Iterable[Tuple[int, ...]]:
4146
"""Iterate over all possible indices of a multidimensional register."""
4247
yield from itertools.product(*[range(sh) for sh in self.shape])
@@ -46,7 +51,7 @@ def total_bits(self) -> int:
4651
4752
This is the product of each of the dimensions in `shape`.
4853
"""
49-
return self.bitsize * int(np.product(self.shape))
54+
return self.bitsize * int(np.prod(self.shape))
5055

5156
def __repr__(self):
5257
return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})'
@@ -137,7 +142,7 @@ def __repr__(self):
137142

138143
@classmethod
139144
def build(cls, **registers: int) -> 'Registers':
140-
return cls(Register(name=k, bitsize=v) for k, v in registers.items())
145+
return cls(Register(name=k, bitsize=v) for k, v in registers.items() if v > 0)
141146

142147
@overload
143148
def __getitem__(self, key: int) -> Register:

Diff for: cirq-ft/cirq_ft/infra/gate_with_registers_test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def test_register():
2525
assert r.bitsize == 5
2626
assert r.shape == (1, 2)
2727

28+
with pytest.raises(ValueError, match="must be a positive integer"):
29+
_ = cirq_ft.Register("zero bitsize register", bitsize=0)
30+
2831

2932
def test_registers():
3033
r1 = cirq_ft.Register("r1", 5)
@@ -96,7 +99,7 @@ def test_selection_registers_indexing(n, N, m, M):
9699
assert np.ravel_multi_index((x, y), (N, M)) == x * M + y
97100
assert np.unravel_index(x * M + y, (N, M)) == (x, y)
98101

99-
assert np.product(tuple(reg.iteration_length for reg in regs)) == N * M
102+
assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M
100103

101104

102105
def test_selection_registers_consistent():

0 commit comments

Comments
 (0)