Skip to content

Commit 425bf8d

Browse files
authored
Add _apply_channel_ optimizations for reset and confusion (#5917)
Confusion matrix and reset channels both can be viewed as starting with a zero DM tensor and then copying (adding) in scaled slices from the original DM tensor. Thus we make a helper function that does this and add `_apply_channel_` optimizations to those gates. Fixes #5901. Starts #5900 though I haven't looked at all gates. Also starts #4579 but there's likely more to do there as well. I didn't add the new test function to the primary test suite because creating superoperators is likely computationally expensive (granted most if not all gates that use this would be three or fewer qubits, which is still cheap), and the test not relevant for most gates.
1 parent 86a9017 commit 425bf8d

9 files changed

+286
-3
lines changed

cirq-core/cirq/linalg/transformations.py

+78-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
"""Utility methods for transforming matrices or vectors."""
1616

17-
from typing import Tuple, Optional, Sequence, List, Union
17+
import dataclasses
18+
from typing import Any, List, Optional, Sequence, Tuple, Union
1819

1920
import numpy as np
2021

@@ -168,6 +169,82 @@ def targeted_left_multiply(
168169
) # type: ignore
169170

170171

172+
@dataclasses.dataclass
173+
class _SliceConfig:
174+
axis: int
175+
source_index: int
176+
target_index: int
177+
178+
179+
@dataclasses.dataclass
180+
class _BuildFromSlicesArgs:
181+
slices: Tuple[_SliceConfig, ...]
182+
scale: complex
183+
184+
185+
def _build_from_slices(
186+
args: Sequence[_BuildFromSlicesArgs], source: np.ndarray, out: np.ndarray
187+
) -> np.ndarray:
188+
"""Populates `out` from the desired slices of `source`.
189+
190+
This function is best described by example.
191+
192+
For instance in 3*3*3 3D space, one could take a cube array, take all the horizontal slices,
193+
and add them up into the top slice leaving everything else zero. If the vertical axis was 1,
194+
and the top was index=2, then this would be written as follows:
195+
196+
_build_from_slices(
197+
[
198+
_BuildFromSlicesArgs((_SliceConfig(axis=1, source_index=0, target_index=2),), 1),
199+
_BuildFromSlicesArgs((_SliceConfig(axis=1, source_index=1, target_index=2),), 1),
200+
_BuildFromSlicesArgs((_SliceConfig(axis=1, source_index=2, target_index=2),), 1),
201+
],
202+
source,
203+
out,
204+
)
205+
206+
When multiple slices are included in the _BuildFromSlicesArgs, this means to take the
207+
intersection of the source space and move it to the intersection of the target space. For
208+
example, the following takes the bottom-left edge and moves it to the top-right, leaving all
209+
other cells zero. Assume the lateral axis is 2 and right-most index thereof is 2:
210+
211+
_build_from_slices(
212+
[
213+
_BuildFromSlicesArgs(
214+
(
215+
_SliceConfig(axis=1, source_index=0, target_index=2), # top
216+
_SliceConfig(axis=2, source_index=0, target_index=2), # right
217+
),
218+
scale=1,
219+
),
220+
],
221+
source,
222+
out,
223+
)
224+
225+
This function is useful for optimizing multiplying a state by one or more one-hot matrices,
226+
as is common when working with Kraus components. It is more efficient than using an einsum.
227+
228+
Args:
229+
args: The list of slice configurations to sum up into the output.
230+
source: The source tensor for the slice data.
231+
out: An output tensor that is the same shape as the source.
232+
233+
Returns:
234+
The output tensor.
235+
"""
236+
d = len(source.shape)
237+
out[...] = 0
238+
for arg in args:
239+
source_slice: List[Any] = [slice(None)] * d
240+
target_slice: List[Any] = [slice(None)] * d
241+
for sleis in arg.slices:
242+
source_slice[sleis.axis] = sleis.source_index
243+
target_slice[sleis.axis] = sleis.target_index
244+
out[tuple(target_slice)] += arg.scale * source[tuple(source_slice)]
245+
return out
246+
247+
171248
def targeted_conjugate_about(
172249
tensor: np.ndarray,
173250
target: np.ndarray,

cirq-core/cirq/ops/common_channels.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
import numpy as np
2121

2222
from cirq import protocols, value
23+
from cirq.linalg import transformations
2324
from cirq.ops import raw_types, common_gates, pauli_gates, identity
2425

25-
2626
if TYPE_CHECKING:
2727
import cirq
2828

@@ -734,6 +734,19 @@ def _kraus_(self) -> Iterable[np.ndarray]:
734734
channel[:, 0, :] = np.eye(self._dimension)
735735
return channel
736736

737+
def _apply_channel_(self, args: 'cirq.ApplyChannelArgs'):
738+
configs = []
739+
for i in range(self._dimension):
740+
s1 = transformations._SliceConfig(
741+
axis=args.left_axes[0], source_index=i, target_index=0
742+
)
743+
s2 = transformations._SliceConfig(
744+
axis=args.right_axes[0], source_index=i, target_index=0
745+
)
746+
configs.append(transformations._BuildFromSlicesArgs(slices=(s1, s2), scale=1))
747+
transformations._build_from_slices(configs, args.target_tensor, out=args.out_buffer)
748+
return args.out_buffer
749+
737750
def _has_kraus_(self) -> bool:
738751
return True
739752

@@ -816,6 +829,23 @@ def __init__(self, gamma: float) -> None:
816829
def _num_qubits_(self) -> int:
817830
return 1
818831

832+
def _apply_channel_(self, args: 'cirq.ApplyChannelArgs'):
833+
if self._gamma == 0:
834+
return args.target_tensor
835+
if self._gamma != 1:
836+
return NotImplemented
837+
configs = []
838+
for i in range(2):
839+
s1 = transformations._SliceConfig(
840+
axis=args.left_axes[0], source_index=i, target_index=i
841+
)
842+
s2 = transformations._SliceConfig(
843+
axis=args.right_axes[0], source_index=i, target_index=i
844+
)
845+
configs.append(transformations._BuildFromSlicesArgs(slices=(s1, s2), scale=1))
846+
transformations._build_from_slices(configs, args.target_tensor, out=args.out_buffer)
847+
return args.out_buffer
848+
819849
def _kraus_(self) -> Iterable[np.ndarray]:
820850
return (
821851
np.array([[1.0, 0.0], [0.0, np.sqrt(1.0 - self._gamma)]]),

cirq-core/cirq/ops/common_channels_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,13 @@ def test_reset_each():
529529
assert op.qubits == (qubits[i],)
530530

531531

532+
def test_reset_consistency():
533+
two_d_chan = cirq.ResetChannel()
534+
cirq.testing.assert_has_consistent_apply_channel(two_d_chan)
535+
three_d_chan = cirq.ResetChannel(dimension=3)
536+
cirq.testing.assert_has_consistent_apply_channel(three_d_chan)
537+
538+
532539
def test_phase_damping_channel():
533540
d = cirq.phase_damp(0.3)
534541
np.testing.assert_almost_equal(
@@ -585,6 +592,15 @@ def test_phase_damping_channel_text_diagram():
585592
)
586593

587594

595+
def test_phase_damp_consistency():
596+
full_damp = cirq.PhaseDampingChannel(gamma=1)
597+
cirq.testing.assert_has_consistent_apply_channel(full_damp)
598+
partial_damp = cirq.PhaseDampingChannel(gamma=0.5)
599+
cirq.testing.assert_has_consistent_apply_channel(partial_damp)
600+
no_damp = cirq.PhaseDampingChannel(gamma=0)
601+
cirq.testing.assert_has_consistent_apply_channel(no_damp)
602+
603+
588604
def test_phase_flip_channel():
589605
d = cirq.phase_flip(0.3)
590606
np.testing.assert_almost_equal(

cirq-core/cirq/ops/gate_operation.py

+8
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,14 @@ def _mixture_(self) -> Sequence[Tuple[float, Any]]:
211211
return getter()
212212
return NotImplemented
213213

214+
def _apply_channel_(
215+
self, args: 'protocols.ApplyChannelArgs'
216+
) -> Union[np.ndarray, None, NotImplementedType]:
217+
getter = getattr(self.gate, '_apply_channel_', None)
218+
if getter is not None:
219+
return getter(args)
220+
return NotImplemented
221+
214222
def _has_kraus_(self) -> bool:
215223
getter = getattr(self.gate, '_has_kraus_', None)
216224
if getter is not None:

cirq-core/cirq/testing/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from cirq.testing.circuit_compare import (
1818
assert_circuits_with_terminal_measurements_are_equivalent,
1919
assert_circuits_have_same_unitary_given_final_permutation,
20+
assert_has_consistent_apply_channel,
2021
assert_has_consistent_apply_unitary,
2122
assert_has_consistent_apply_unitary_for_various_exponents,
2223
assert_has_diagram,

cirq-core/cirq/testing/circuit_compare.py

+44
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,50 @@ def assert_has_consistent_apply_unitary(val: Any, *, atol: float = 1e-8) -> None
333333
np.testing.assert_allclose(actual.reshape(n, n), expected, atol=atol)
334334

335335

336+
def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None:
337+
"""Tests whether a value's _apply_channel_ is correct.
338+
339+
Contrasts the effects of the value's `_apply_channel_` with the superoperator calculated from
340+
the Kraus components returned by the value's `_kraus_` method.
341+
342+
Args:
343+
val: The value under test. Should have a `__pow__` method.
344+
atol: Absolute error tolerance.
345+
"""
346+
# pylint: disable=unused-variable
347+
__tracebackhide__ = True
348+
# pylint: enable=unused-variable
349+
350+
kraus = protocols.kraus(val, default=None)
351+
expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None
352+
353+
qid_shape = protocols.qid_shape(val)
354+
355+
eye = qis.eye_tensor(qid_shape * 2, dtype=np.complex128)
356+
actual = protocols.apply_channel(
357+
val=val,
358+
args=protocols.ApplyChannelArgs(
359+
target_tensor=eye,
360+
out_buffer=np.ones_like(eye) * float('nan'),
361+
auxiliary_buffer0=np.ones_like(eye) * float('nan'),
362+
auxiliary_buffer1=np.ones_like(eye) * float('nan'),
363+
left_axes=list(range(len(qid_shape))),
364+
right_axes=list(range(len(qid_shape), len(qid_shape) * 2)),
365+
),
366+
default=None,
367+
)
368+
369+
# If you don't have a Kraus, you shouldn't be able to apply a channel.
370+
if expected is None:
371+
assert actual is None
372+
373+
# If you applied a channel, it should match the superoperator you say you have.
374+
if actual is not None:
375+
assert expected is not None
376+
n = np.product(qid_shape) ** 2
377+
np.testing.assert_allclose(actual.reshape((n, n)), expected, atol=atol)
378+
379+
336380
def _assert_apply_unitary_works_when_axes_transposed(val: Any, *, atol: float = 1e-8) -> None:
337381
"""Tests whether a value's _apply_unitary_ handles out-of-order axes.
338382

cirq-core/cirq/testing/circuit_compare_test.py

+66
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,72 @@ def test_assert_has_diagram():
262262
assert expected_error in ex_info.value.args[0]
263263

264264

265+
def test_assert_has_consistent_apply_channel():
266+
class Correct:
267+
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
268+
args.target_tensor[...] = 0
269+
return args.target_tensor
270+
271+
def _kraus_(self):
272+
return [np.array([[0, 0], [0, 0]])]
273+
274+
def _num_qubits_(self):
275+
return 1
276+
277+
cirq.testing.assert_has_consistent_apply_channel(Correct())
278+
279+
class Wrong:
280+
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
281+
args.target_tensor[...] = 0
282+
return args.target_tensor
283+
284+
def _kraus_(self):
285+
return [np.array([[1, 0], [0, 0]])]
286+
287+
def _num_qubits_(self):
288+
return 1
289+
290+
with pytest.raises(AssertionError):
291+
cirq.testing.assert_has_consistent_apply_channel(Wrong())
292+
293+
class NoNothing:
294+
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
295+
return NotImplemented
296+
297+
def _kraus_(self):
298+
return NotImplemented
299+
300+
def _num_qubits_(self):
301+
return 1
302+
303+
cirq.testing.assert_has_consistent_apply_channel(NoNothing())
304+
305+
class NoKraus:
306+
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
307+
return args.target_tensor
308+
309+
def _kraus_(self):
310+
return NotImplemented
311+
312+
def _num_qubits_(self):
313+
return 1
314+
315+
with pytest.raises(AssertionError):
316+
cirq.testing.assert_has_consistent_apply_channel(NoKraus())
317+
318+
class NoApply:
319+
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
320+
return NotImplemented
321+
322+
def _kraus_(self):
323+
return [np.array([[0, 0], [0, 0]])]
324+
325+
def _num_qubits_(self):
326+
return 1
327+
328+
cirq.testing.assert_has_consistent_apply_channel(NoApply())
329+
330+
265331
def test_assert_has_consistent_apply_unitary():
266332
class IdentityReturningUnalteredWorkspace:
267333
def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:

cirq-core/cirq/transformers/measurement_transformers.py

+27
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919

2020
from cirq import linalg, ops, protocols, value
21+
from cirq.linalg import transformations
2122
from cirq.transformers import transformer_api, transformer_primitives
2223
from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements
2324

@@ -341,6 +342,7 @@ def __init__(self, confusion_map: np.ndarray, shape: Sequence[int]):
341342
if not linalg.is_cptp(kraus_ops=kraus):
342343
raise ValueError('Confusion map has invalid probabilities.')
343344
self._shape = tuple(shape)
345+
self._confusion_map = confusion_map.copy()
344346
self._kraus = tuple(kraus)
345347

346348
def _qid_shape_(self) -> Tuple[int, ...]:
@@ -349,6 +351,31 @@ def _qid_shape_(self) -> Tuple[int, ...]:
349351
def _kraus_(self) -> Tuple[np.ndarray, ...]:
350352
return self._kraus
351353

354+
def _apply_channel_(self, args: 'cirq.ApplyChannelArgs'):
355+
configs = []
356+
for i in range(np.prod(self._shape) ** 2):
357+
scale = self._confusion_map.flat[i]
358+
if scale == 0:
359+
continue
360+
index: Any = np.unravel_index(i, self._shape * 2)
361+
slices = []
362+
axis_count = len(args.left_axes)
363+
for j in range(axis_count):
364+
s1 = transformations._SliceConfig(
365+
axis=args.left_axes[j],
366+
source_index=index[j],
367+
target_index=index[j + axis_count],
368+
)
369+
s2 = transformations._SliceConfig(
370+
axis=args.right_axes[j],
371+
source_index=index[j],
372+
target_index=index[j + axis_count],
373+
)
374+
slices.extend([s1, s2])
375+
configs.append(transformations._BuildFromSlicesArgs(slices=tuple(slices), scale=scale))
376+
transformations._build_from_slices(configs, args.target_tensor, out=args.out_buffer)
377+
return args.out_buffer
378+
352379

353380
@value.value_equality
354381
class _ModAdd(ops.ArithmeticGate):

cirq-core/cirq/transformers/measurement_transformers_test.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import sympy
1818

1919
import cirq
20-
from cirq.transformers.measurement_transformers import _mod_add, _MeasurementQid
20+
from cirq.transformers.measurement_transformers import _ConfusionChannel, _MeasurementQid, _mod_add
2121

2222

2323
def assert_equivalent_to_deferred(circuit: cirq.Circuit):
@@ -575,3 +575,17 @@ def test_drop_terminal_nonterminal_error():
575575

576576
with pytest.raises(ValueError, match='Context has `deep=False`'):
577577
_ = cirq.drop_terminal_measurements(circuit, context=None)
578+
579+
580+
def test_confusion_channel_consistency():
581+
two_d_chan = _ConfusionChannel(np.array([[0.5, 0.5], [0.4, 0.6]]), shape=(2,))
582+
cirq.testing.assert_has_consistent_apply_channel(two_d_chan)
583+
three_d_chan = _ConfusionChannel(
584+
np.array([[0.5, 0.3, 0.2], [0.4, 0.5, 0.1], [0, 0, 1]]), shape=(3,)
585+
)
586+
cirq.testing.assert_has_consistent_apply_channel(three_d_chan)
587+
two_q_chan = _ConfusionChannel(
588+
np.array([[0.5, 0.3, 0.1, 0.1], [0.4, 0.5, 0.1, 0], [0, 0, 1, 0], [0, 0, 0.5, 0.5]]),
589+
shape=(2, 2),
590+
)
591+
cirq.testing.assert_has_consistent_apply_channel(two_q_chan)

0 commit comments

Comments
 (0)