Skip to content

Commit 6fae409

Browse files
authored
Delete SelectionRegisters and replace uses of Registers with Tuple[Register, ...] (#6278)
* Delete SelectionRegisters and replace uses of Registers with Tuple[Register, ...] * Add type ignore to fix mypy error * Address Matt's comments
1 parent e235642 commit 6fae409

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+488
-553
lines changed

cirq-ft/cirq_ft/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
Register,
5050
Registers,
5151
SelectionRegister,
52-
SelectionRegisters,
5352
TComplexity,
5453
map_clean_and_borrowable_qubits,
5554
t_complexity,

cirq-ft/cirq_ft/algos/and_gate.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@
6363
"source": [
6464
"import cirq\n",
6565
"from cirq.contrib.svg import SVGCircuit\n",
66-
"from cirq_ft import And\n",
66+
"from cirq_ft import And, infra\n",
6767
"\n",
6868
"gate = And()\n",
6969
"r = gate.registers\n",
70-
"quregs = r.get_named_qubits()\n",
70+
"quregs = infra.get_named_qubits(r)\n",
7171
"operation = gate.on_registers(**quregs)\n",
7272
"circuit = cirq.Circuit(operation)\n",
7373
"SVGCircuit(circuit)"
@@ -223,4 +223,4 @@
223223
},
224224
"nbformat": 4,
225225
"nbformat_minor": 5
226-
}
226+
}

cirq-ft/cirq_ft/algos/and_gate.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class And(infra.GateWithRegisters):
4949
ValueError: If number of control values (i.e. `len(self.cv)`) is less than 2.
5050
"""
5151

52-
cv: Tuple[int, ...] = attr.field(default=(1, 1), converter=infra.to_tuple)
52+
cv: Tuple[int, ...] = attr.field(
53+
default=(1, 1), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
54+
)
5355
adjoint: bool = False
5456

5557
@cv.validator

cirq-ft/cirq_ft/algos/and_gate_test.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import cirq_ft
2121
import numpy as np
2222
import pytest
23+
from cirq_ft import infra
2324
from cirq_ft.infra.jupyter_tools import execute_notebook
2425

2526
random.seed(12345)
@@ -46,12 +47,12 @@ def test_multi_controlled_and_gate(cv: List[int]):
4647
gate = cirq_ft.And(cv)
4748
r = gate.registers
4849
assert r['ancilla'].total_bits() == r['control'].total_bits() - 2
49-
quregs = r.get_named_qubits()
50+
quregs = infra.get_named_qubits(r)
5051
and_op = gate.on_registers(**quregs)
5152
circuit = cirq.Circuit(and_op)
5253

5354
input_controls = [cv] + [random_cv(len(cv)) for _ in range(10)]
54-
qubit_order = gate.registers.merge_qubits(**quregs)
55+
qubit_order = infra.merge_qubits(gate.registers, **quregs)
5556

5657
for input_control in input_controls:
5758
initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1)
@@ -77,7 +78,7 @@ def test_multi_controlled_and_gate(cv: List[int]):
7778

7879
def test_and_gate_diagram():
7980
gate = cirq_ft.And((1, 0, 1, 0, 1, 0))
80-
qubit_regs = gate.registers.get_named_qubits()
81+
qubit_regs = infra.get_named_qubits(gate.registers)
8182
op = gate.on_registers(**qubit_regs)
8283
# Qubit order should be alternating (control, ancilla) pairs.
8384
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + (

cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
"`selection`-th qubit of `target` all controlled by the `control` register.\n",
7070
"\n",
7171
"#### Parameters\n",
72-
" - `selection_regs`: Indexing `select` registers of type `SelectionRegisters`. It also contains information about the iteration length of each selection register.\n",
72+
" - `selection_regs`: Indexing `select` registers of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n",
7373
" - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n",
7474
" - `control_regs`: Control registers for constructing a controlled version of the gate.\n"
7575
]
@@ -89,7 +89,7 @@
8989
" return cirq.I\n",
9090
"\n",
9191
"apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n",
92-
" cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n",
92+
" cirq_ft.SelectionRegister('selection', 3, 4),\n",
9393
" nth_gate=_z_to_odd,\n",
9494
" control_regs=cirq_ft.Registers.build(control=2),\n",
9595
")\n",

cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py

+26-22
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# limitations under the License.
1414

1515
import itertools
16-
from typing import Callable, Sequence
16+
from typing import Callable, Sequence, Tuple
1717

1818
import attr
1919
import cirq
20+
import numpy as np
2021
from cirq._compat import cached_property
2122
from cirq_ft import infra
2223
from cirq_ft.algos import unary_iteration_gate
@@ -36,8 +37,8 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
3637
`selection`-th qubit of `target` all controlled by the `control` register.
3738
3839
Args:
39-
selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains
40-
information about the iteration length of each selection register.
40+
selection_regs: Indexing `select` registers of type Tuple[`SelectionRegisters`, ...].
41+
It also contains information about the iteration length of each selection register.
4142
nth_gate: A function mapping the composite selection index to a single-qubit gate.
4243
control_regs: Control registers for constructing a controlled version of the gate.
4344
@@ -46,43 +47,45 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate):
4647
(https://arxiv.org/abs/1805.03662).
4748
Babbush et. al. (2018). Section III.A. and Figure 7.
4849
"""
49-
selection_regs: infra.SelectionRegisters
50+
selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field(
51+
converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v)
52+
)
5053
nth_gate: Callable[..., cirq.Gate]
51-
control_regs: infra.Registers = infra.Registers.build(control=1)
54+
control_regs: Tuple[infra.Register, ...] = attr.field(
55+
converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v),
56+
default=(infra.Register('control', 1),),
57+
)
5258

5359
@classmethod
5460
def make_on(
5561
cls, *, nth_gate: Callable[..., cirq.Gate], **quregs: Sequence[cirq.Qid]
5662
) -> cirq.Operation:
5763
"""Helper constructor to automatically deduce bitsize attributes."""
58-
return cls(
59-
infra.SelectionRegisters(
60-
[
61-
infra.SelectionRegister(
62-
'selection', len(quregs['selection']), len(quregs['target'])
63-
)
64-
]
65-
),
64+
return ApplyGateToLthQubit(
65+
infra.SelectionRegister('selection', len(quregs['selection']), len(quregs['target'])),
6666
nth_gate=nth_gate,
67-
control_regs=infra.Registers.build(control=len(quregs['control'])),
67+
control_regs=infra.Register('control', len(quregs['control'])),
6868
).on_registers(**quregs)
6969

7070
@cached_property
71-
def control_registers(self) -> infra.Registers:
71+
def control_registers(self) -> Tuple[infra.Register, ...]:
7272
return self.control_regs
7373

7474
@cached_property
75-
def selection_registers(self) -> infra.SelectionRegisters:
75+
def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
7676
return self.selection_regs
7777

7878
@cached_property
79-
def target_registers(self) -> infra.Registers:
80-
return infra.Registers.build(target=self.selection_registers.total_iteration_size)
79+
def target_registers(self) -> Tuple[infra.Register, ...]:
80+
total_iteration_size = np.product(
81+
tuple(reg.iteration_length for reg in self.selection_registers)
82+
)
83+
return (infra.Register('target', int(total_iteration_size)),)
8184

8285
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
83-
wire_symbols = ["@"] * self.control_registers.total_bits()
84-
wire_symbols += ["In"] * self.selection_registers.total_bits()
85-
for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]):
86+
wire_symbols = ["@"] * infra.total_bits(self.control_registers)
87+
wire_symbols += ["In"] * infra.total_bits(self.selection_registers)
88+
for it in itertools.product(*[range(reg.iteration_length) for reg in self.selection_regs]):
8689
wire_symbols += [str(self.nth_gate(*it))]
8790
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
8891

@@ -93,6 +96,7 @@ def nth_operation( # type: ignore[override]
9396
target: Sequence[cirq.Qid],
9497
**selection_indices: int,
9598
) -> cirq.OP_TREE:
99+
selection_shape = tuple(reg.iteration_length for reg in self.selection_regs)
96100
selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs)
97-
target_idx = self.selection_registers.to_flat_idx(*selection_idx)
101+
target_idx = int(np.ravel_multi_index(selection_idx, selection_shape))
98102
return self.nth_gate(*selection_idx).on(target[target_idx]).controlled_by(control)

cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import cirq
1616
import cirq_ft
1717
import pytest
18+
from cirq_ft import infra
1819
from cirq_ft.infra.bit_tools import iter_bits
1920
from cirq_ft.infra.jupyter_tools import execute_notebook
2021

@@ -23,16 +24,13 @@
2324
def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
2425
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True)
2526
gate = cirq_ft.ApplyGateToLthQubit(
26-
cirq_ft.SelectionRegisters(
27-
[cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)]
28-
),
29-
lambda _: cirq.X,
27+
cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), lambda _: cirq.X
3028
)
3129
g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm))
3230
# Upper bounded because not all ancillas may be used as part of unary iteration.
3331
assert (
3432
len(g.all_qubits)
35-
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1
33+
<= target_bitsize + 2 * (selection_bitsize + infra.total_bits(gate.control_registers)) - 1
3634
)
3735

3836
for n in range(target_bitsize):
@@ -54,12 +52,12 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
5452
def test_apply_gate_to_lth_qubit_diagram():
5553
# Apply Z gate to all odd targets and Identity to even targets.
5654
gate = cirq_ft.ApplyGateToLthQubit(
57-
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
55+
cirq_ft.SelectionRegister('selection', 3, 5),
5856
lambda n: cirq.Z if n & 1 else cirq.I,
5957
control_regs=cirq_ft.Registers.build(control=2),
6058
)
61-
circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits()))
62-
qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v)
59+
circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers)))
60+
qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v)
6361
cirq.testing.assert_has_diagram(
6462
circuit,
6563
"""
@@ -89,13 +87,13 @@ def test_apply_gate_to_lth_qubit_diagram():
8987

9088
def test_apply_gate_to_lth_qubit_make_on():
9189
gate = cirq_ft.ApplyGateToLthQubit(
92-
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
90+
cirq_ft.SelectionRegister('selection', 3, 5),
9391
lambda n: cirq.Z if n & 1 else cirq.I,
9492
control_regs=cirq_ft.Registers.build(control=2),
9593
)
96-
op = gate.on_registers(**gate.registers.get_named_qubits())
94+
op = gate.on_registers(**infra.get_named_qubits(gate.registers))
9795
op2 = cirq_ft.ApplyGateToLthQubit.make_on(
98-
nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **gate.registers.get_named_qubits()
96+
nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.registers)
9997
)
10098
# Note: ApplyGateToLthQubit doesn't support value equality.
10199
assert op.qubits == op2.qubits

cirq-ft/cirq_ft/algos/arithmetic_gates.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,9 @@ class AddMod(cirq.ArithmeticGate):
611611
bitsize: int
612612
mod: int = attr.field()
613613
add_val: int = 1
614-
cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=())
614+
cv: Tuple[int, ...] = attr.field(
615+
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
616+
)
615617

616618
@mod.validator
617619
def _validate_mod(self, attribute, value):

cirq-ft/cirq_ft/algos/generic_select.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,20 @@ def __attrs_post_init__(self):
6868
)
6969

7070
@cached_property
71-
def control_registers(self) -> infra.Registers:
72-
registers = [] if self.control_val is None else [infra.Register('control', 1)]
73-
return infra.Registers(registers)
71+
def control_registers(self) -> Tuple[infra.Register, ...]:
72+
return () if self.control_val is None else (infra.Register('control', 1),)
7473

7574
@cached_property
76-
def selection_registers(self) -> infra.SelectionRegisters:
77-
return infra.SelectionRegisters(
78-
[
79-
infra.SelectionRegister(
80-
'selection', self.selection_bitsize, len(self.select_unitaries)
81-
)
82-
]
75+
def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:
76+
return (
77+
infra.SelectionRegister(
78+
'selection', self.selection_bitsize, len(self.select_unitaries)
79+
),
8380
)
8481

8582
@cached_property
86-
def target_registers(self) -> infra.Registers:
87-
return infra.Registers.build(target=self.target_bitsize)
83+
def target_registers(self) -> Tuple[infra.Register, ...]:
84+
return (infra.Register('target', self.target_bitsize),)
8885

8986
def decompose_from_registers(
9087
self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var]

cirq-ft/cirq_ft/algos/generic_select_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import cirq_ft
1818
import numpy as np
1919
import pytest
20+
from cirq_ft import infra
2021
from cirq_ft.infra.bit_tools import iter_bits
2122
from cirq_ft.infra.jupyter_tools import execute_notebook
2223

@@ -255,7 +256,7 @@ def test_generic_select_consistent_protocols_and_controlled():
255256

256257
# Build GenericSelect gate.
257258
gate = cirq_ft.GenericSelect(select_bitsize, num_sites, dps_hamiltonian)
258-
op = gate.on_registers(**gate.registers.get_named_qubits())
259+
op = gate.on_registers(**infra.get_named_qubits(gate.registers))
259260
cirq.testing.assert_equivalent_repr(gate, setup_code='import cirq\nimport cirq_ft')
260261

261262
# Build controlled gate

0 commit comments

Comments
 (0)