From c805a6fa12b2ebd5d8e7b8e87268d0f4c36ce2b1 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 21 Sep 2023 15:45:17 -0700 Subject: [PATCH 1/2] Do not allow creating registers with bitsize 0 --- cirq-ft/cirq_ft/algos/and_gate.py | 2 +- cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py | 2 +- cirq-ft/cirq_ft/algos/qrom.py | 6 ++++-- cirq-ft/cirq_ft/algos/state_preparation.py | 2 +- cirq-ft/cirq_ft/infra/gate_with_registers.py | 9 +++++++-- cirq-ft/cirq_ft/infra/gate_with_registers_test.py | 3 +++ 6 files changed, 17 insertions(+), 7 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index f308926d632..d8c1b79cbd9 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -141,7 +141,7 @@ def _decompose_via_tree( def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target'] + control, ancilla, target = quregs['control'], quregs.get('ancilla', ()), quregs['target'] if len(self.cv) == 2: yield self._decompose_single_and( self.cv[0], self.cv[1], control[0], control[1], *target diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index 374415e90bc..6497e3d65c5 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -69,7 +69,7 @@ def decompose_from_registers( context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: - controls, target = quregs['controls'], quregs['target'] + controls, target = quregs.get('controls', ()), quregs['target'] # Find K and L as per https://arxiv.org/abs/1805.03662 Fig 12. n, k = self.n, 0 while n > 1 and n % 2 == 0: diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index 8d09d82ed9b..9feb90ad125 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -111,7 +111,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - return tuple(infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes)) + return tuple( + infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes) if l + ) def __repr__(self) -> str: data_repr = f"({','.join(cirq._compat.proper_repr(d) for d in self.data)})" @@ -129,7 +131,7 @@ def _load_nth_data( **target_regs: NDArray[cirq.Qid], # type: ignore[type-var] ) -> cirq.OP_TREE: for i, d in enumerate(self.data): - target = target_regs[f'target{i}'] + target = target_regs.get(f'target{i}', ()) for q, bit in zip(target, f'{int(d[selection_idx]):0{len(target)}b}'): if int(bit): yield gate(q) diff --git a/cirq-ft/cirq_ft/algos/state_preparation.py b/cirq-ft/cirq_ft/algos/state_preparation.py index bec54f50a6b..aa660b5ebf8 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.py +++ b/cirq-ft/cirq_ft/algos/state_preparation.py @@ -167,7 +167,7 @@ def decompose_from_registers( **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: selection, less_than_equal = quregs['selection'], quregs['less_than_equal'] - sigma_mu, alt, keep = quregs['sigma_mu'], quregs['alt'], quregs['keep'] + sigma_mu, alt, keep = quregs.get('sigma_mu', ()), quregs['alt'], quregs.get('keep', ()) N = self.selection_registers[0].iteration_length yield prepare_uniform_superposition.PrepareUniformSuperposition(N).on(*selection) yield cirq.H.on_each(*sigma_mu) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index b4567591c67..d299bfe0eac 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -32,11 +32,16 @@ class Register: """ name: str - bitsize: int + bitsize: int = attr.field() shape: Tuple[int, ...] = attr.field( converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() ) + @bitsize.validator + def bitsize_validator(self, attribute, value): + if value <= 0: + raise ValueError(f"Bitsize for {self=} must be a positive integer. Found {value}.") + def all_idxs(self) -> Iterable[Tuple[int, ...]]: """Iterate over all possible indices of a multidimensional register.""" yield from itertools.product(*[range(sh) for sh in self.shape]) @@ -137,7 +142,7 @@ def __repr__(self): @classmethod def build(cls, **registers: int) -> 'Registers': - return cls(Register(name=k, bitsize=v) for k, v in registers.items()) + return cls(Register(name=k, bitsize=v) for k, v in registers.items() if v > 0) @overload def __getitem__(self, key: int) -> Register: diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 57af2354e48..2cd72e4c1fa 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -25,6 +25,9 @@ def test_register(): assert r.bitsize == 5 assert r.shape == (1, 2) + with pytest.raises(ValueError, match="must be a positive integer"): + _ = cirq_ft.Register("zero bitsize register", bitsize=0) + def test_registers(): r1 = cirq_ft.Register("r1", 5) From a63eee21db113c627e67ece48d44c8dc5253e24c Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Sat, 23 Sep 2023 22:27:47 -0700 Subject: [PATCH 2/2] Fix mypy errors --- cirq-ft/cirq_ft/algos/and_gate.py | 6 +++++- cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py | 2 +- cirq-ft/cirq_ft/algos/selected_majorana_fermion.py | 2 +- cirq-ft/cirq_ft/infra/gate_with_registers.py | 2 +- cirq-ft/cirq_ft/infra/gate_with_registers_test.py | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index d8c1b79cbd9..b34f632f5ff 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -141,7 +141,11 @@ def _decompose_via_tree( def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - control, ancilla, target = quregs['control'], quregs.get('ancilla', ()), quregs['target'] + control, ancilla, target = ( + quregs['control'], + quregs.get('ancilla', np.array([])), + quregs['target'], + ) if len(self.cv) == 2: yield self._decompose_single_and( self.cv[0], self.cv[1], control[0], control[1], *target diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py index e3bb08be143..25f80dfa00d 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py @@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - total_iteration_size = np.product( + total_iteration_size = np.prod( tuple(reg.iteration_length for reg in self.selection_registers) ) return (infra.Register('target', int(total_iteration_size)),) diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py index 877c81f39a3..a97eb752adb 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py @@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - total_iteration_size = np.product( + total_iteration_size = np.prod( tuple(reg.iteration_length for reg in self.selection_registers) ) return (infra.Register('target', int(total_iteration_size)),) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index d299bfe0eac..624397ab479 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -51,7 +51,7 @@ def total_bits(self) -> int: This is the product of each of the dimensions in `shape`. """ - return self.bitsize * int(np.product(self.shape)) + return self.bitsize * int(np.prod(self.shape)) def __repr__(self): return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})' diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 2cd72e4c1fa..77e60aacbe8 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -99,7 +99,7 @@ def test_selection_registers_indexing(n, N, m, M): assert np.ravel_multi_index((x, y), (N, M)) == x * M + y assert np.unravel_index(x * M + y, (N, M)) == (x, y) - assert np.product(tuple(reg.iteration_length for reg in regs)) == N * M + assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M def test_selection_registers_consistent():