Skip to content

Commit 38775be

Browse files
authored
SimulatorBase independent qubits optimization (quantumlib#4100)
Add optimization that ensures independent qubit sets are simulated independently. This is done by adding join, extract, and reorder methods to ActOnArgs, and updating SimulatorBase with the logic to merge qubit sets when necessary and split them when possible. This optimization is enabled or disabled via a new parameter in the simulator constructors: `split_entangled_qubits`. Currently the PR has this set to True by default, though perhaps it should be disabled by default lest it breaks anything? The MPS simulator does not yet have `extract` defined and thus there's no option to enable this feature in MPS simulator's constructor yet, though nothing prevents this from being added later. The perf boost of this implementation is limited because each StepResult still requires the full product state. It's still a speedup because full product state calculations will only have to occur once per moment rather than once per operation, but not as nice as avoiding full product state calculations entirely. *That* optimization will be available in a subsequent PR that never creates the full product state if possible: StepResults will join the product state only on demand, and sampling will sample each substate independently and zip up the results, avoiding the full state join: The WIP is here https://github.com/daxfohl/Cirq/compare/split...daxfohl:sample?expand=1. I ramped up the number of qubits in the benchmarks to 25 for sparse and 12 for DM: From master: ``` (cirq-py3) dax@DESKTOP-Q5MLJ3J:~/cirq$ time pytest dev_tools/profiling/benchmark_simulators_test.py platform linux -- Python 3.8.5, pytest-5.4.3, py-1.10.0, pluggy-0.13.1 benchmark: 3.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) rootdir: /home/dax/cirq plugins: cov-2.5.1, asyncio-0.12.0, benchmark-3.2.3 collected 5 items dev_tools/profiling/benchmark_simulators_test.py ..... [100%] real 0m16.973s user 0m15.754s sys 0m3.862s ``` From split branch (the current PR): ``` (cirq-py3) dax@DESKTOP-Q5MLJ3J:~/cirq$ time pytest dev_tools/profiling/benchmark_simulators_test.py platform linux -- Python 3.8.5, pytest-5.4.3, py-1.10.0, pluggy-0.13.1 benchmark: 3.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) rootdir: /home/dax/cirq plugins: cov-2.5.1, asyncio-0.12.0, benchmark-3.2.3 collected 5 items dev_tools/profiling/benchmark_simulators_test.py ..... [100%] real 0m10.073s user 0m9.082s sys 0m3.805s ``` From sample branch (future iteration mentioned above): ``` (cirq-py3) dax@DESKTOP-Q5MLJ3J:~/cirq$ time pytest dev_tools/profiling/benchmark_simulators_test.py platform linux -- Python 3.8.5, pytest-5.4.3, py-1.10.0, pluggy-0.13.1 benchmark: 3.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) rootdir: /home/dax/cirq plugins: cov-2.5.1, asyncio-0.12.0, benchmark-3.2.3 collected 5 items dev_tools/profiling/benchmark_simulators_test.py ..... [100%] real 0m2.885s user 0m3.523s sys 0m2.597s ``` Initial PR for quantumlib#3240 Closes quantumlib#882
1 parent a729c41 commit 38775be

27 files changed

+1380
-260
lines changed

cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@
358358

359359
from cirq.sim import (
360360
ActOnArgs,
361+
ActOnArgsContainer,
361362
ActOnCliffordTableauArgs,
362363
ActOnDensityMatrixArgs,
363364
ActOnStabilizerCHFormArgs,
@@ -376,6 +377,7 @@
376377
measure_state_vector,
377378
final_density_matrix,
378379
final_state_vector,
380+
OperationTarget,
379381
sample,
380382
sample_density_matrix,
381383
sample_state_vector,

cirq/contrib/quimb/mps_simulator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ def __init__(
8585
seed=seed,
8686
)
8787

88-
def _create_act_on_args(
88+
def _create_partial_act_on_args(
8989
self,
9090
initial_state: Union[int, 'MPSState'],
9191
qubits: Sequence['cirq.Qid'],
92+
logs: Dict[str, Any],
9293
) -> 'MPSState':
9394
"""Creates MPSState args for simulating the Circuit.
9495
@@ -111,6 +112,7 @@ def _create_act_on_args(
111112
simulation_options=self.simulation_options,
112113
grouping=self.grouping,
113114
initial_state=initial_state,
115+
log_of_measurement_results=logs,
114116
)
115117

116118
def _create_step_result(
@@ -317,6 +319,7 @@ def copy(self) -> 'MPSState':
317319
prng=self.prng,
318320
simulation_options=self.simulation_options,
319321
grouping=self.grouping,
322+
log_of_measurement_results=self.log_of_measurement_results.copy(),
320323
)
321324
state.M = [x.copy() for x in self.M]
322325
state.estimated_gate_error_list = self.estimated_gate_error_list

cirq/linalg/transformations.py

+157
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,160 @@ def to_special(u: np.ndarray) -> np.ndarray:
500500
the special unitary matrix
501501
"""
502502
return u * (np.linalg.det(u) ** (-1 / len(u)))
503+
504+
505+
def state_vector_kronecker_product(
506+
t1: np.ndarray,
507+
t2: np.ndarray,
508+
) -> np.ndarray:
509+
"""Merges two state vectors into a single unified state vector.
510+
511+
The resulting vector's shape will be `t1.shape + t2.shape`.
512+
513+
Args:
514+
t1: The first state vector.
515+
t2: The second state vector.
516+
Returns:
517+
A new state vector representing the unified state.
518+
"""
519+
return np.outer(t1, t2).reshape(t1.shape + t2.shape)
520+
521+
522+
def density_matrix_kronecker_product(
523+
t1: np.ndarray,
524+
t2: np.ndarray,
525+
) -> np.ndarray:
526+
"""Merges two density matrices into a single unified density matrix.
527+
528+
The resulting matrix's shape will be `(t1.shape/2 + t2.shape/2) * 2`. In
529+
other words, if t1 has shape [A,B,C,A,B,C] and t2 has shape [X,Y,Z,X,Y,Z],
530+
the resulting matrix will have shape [A,B,C,X,Y,Z,A,B,C,X,Y,Z].
531+
532+
Args:
533+
t1: The first density matrix.
534+
t2: The second density matrix.
535+
Returns:
536+
A density matrix representing the unified state.
537+
"""
538+
t = state_vector_kronecker_product(t1, t2)
539+
t1_len = len(t1.shape)
540+
t1_dim = int(t1_len / 2)
541+
t2_len = len(t2.shape)
542+
t2_dim = int(t2_len / 2)
543+
shape = t1.shape[:t1_dim] + t2.shape[:t2_dim]
544+
return np.moveaxis(t, range(t1_len, t1_len + t2_dim), range(t1_dim, t1_dim + t2_dim)).reshape(
545+
shape * 2
546+
)
547+
548+
549+
def factor_state_vector(
550+
t: np.ndarray,
551+
axes: Sequence[int],
552+
*,
553+
validate=True,
554+
atol=1e-07,
555+
) -> Tuple[np.ndarray, np.ndarray]:
556+
"""Factors a state vector into two independent state vectors.
557+
558+
This function should only be called on state vectors that are known to be
559+
separable, such as immediately after a measurement or reset operation. It
560+
does not verify that the provided state vector is indeed separable, and
561+
will return nonsense results for vectors representing entangled states.
562+
563+
Args:
564+
t: The state vector to factor.
565+
axes: The axes to factor out.
566+
validate: Perform a validation that the density matrix factors cleanly.
567+
atol: The absolute tolerance for the validation.
568+
Returns:
569+
A tuple with the `(extracted, remainder)` state vectors, where
570+
`extracted` means the sub-state vector which corresponds to the axes
571+
requested, and with the axes in the requested order, and where
572+
`remainder` means the sub-state vector on the remaining axes, in the
573+
same order as the original state vector.
574+
"""
575+
n_axes = len(axes)
576+
t1 = np.moveaxis(t, axes, range(n_axes))
577+
pivot = np.unravel_index(np.abs(t1).argmax(), t1.shape)
578+
slices1 = (slice(None),) * n_axes + pivot[n_axes:]
579+
slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes)
580+
extracted = t1[slices1]
581+
extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5
582+
remainder = t1[slices2]
583+
remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
584+
if validate:
585+
t2 = state_vector_kronecker_product(extracted, remainder)
586+
axes2 = list(axes) + [i for i in range(t1.ndim) if i not in axes]
587+
t3 = transpose_state_vector_to_axis_order(t2, axes2)
588+
if not np.allclose(t3, t, atol=atol):
589+
raise ValueError('The tensor cannot be factored by the requested axes')
590+
return extracted, remainder
591+
592+
593+
def factor_density_matrix(
594+
t: np.ndarray,
595+
axes: Sequence[int],
596+
*,
597+
validate=True,
598+
atol=1e-07,
599+
) -> Tuple[np.ndarray, np.ndarray]:
600+
"""Factors a density matrix into two independent density matrices.
601+
602+
This function should only be called on density matrices that are known to
603+
be separable, such as immediately after a measurement or reset operation.
604+
It does not verify that the provided density matrix is indeed separable,
605+
and will return nonsense results for matrices representing entangled
606+
states.
607+
608+
Args:
609+
t: The density matrix to factor.
610+
axes: The axes to factor out. Only the left axes should be provided.
611+
For example, to extract [C,A] from density matrix of shape
612+
[A,B,C,D,A,B,C,D], `axes` should be [2,0], and the return value
613+
will be two density matrices ([C,A,C,A], [B,D,B,D]).
614+
validate: Perform a validation that the density matrix factors cleanly.
615+
atol: The absolute tolerance for the validation.
616+
Returns:
617+
A tuple with the `(extracted, remainder)` density matrices, where
618+
`extracted` means the sub-matrix which corresponds to the axes
619+
requested, and with the axes in the requested order, and where
620+
`remainder` means the sub-matrix on the remaining axes, in the same
621+
order as the original density matrix.
622+
"""
623+
axes1 = list(axes) + [i + int(t.ndim / 2) for i in axes]
624+
extracted, remainder = factor_state_vector(t, axes1, validate=False)
625+
if validate:
626+
t1 = density_matrix_kronecker_product(extracted, remainder)
627+
axes2 = list(axes) + [i for i in range(int(t.ndim / 2)) if i not in axes]
628+
t2 = transpose_density_matrix_to_axis_order(t1, axes2)
629+
if not np.allclose(t2, t, atol=atol):
630+
raise ValueError('The tensor cannot be factored by the requested axes')
631+
return extracted, remainder
632+
633+
634+
def transpose_state_vector_to_axis_order(t: np.ndarray, axes: Sequence[int]):
635+
"""Transposes the axes of a state vector to a specified order.
636+
637+
Args:
638+
t: The state vector to transpose.
639+
axes: The desired axis order.
640+
Returns:
641+
The transposed state vector.
642+
"""
643+
assert set(axes) == set(range(int(t.ndim))), "All axes must be provided."
644+
return np.moveaxis(t, axes, range(len(axes)))
645+
646+
647+
def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]):
648+
"""Transposes the axes of a density matrix to a specified order.
649+
650+
Args:
651+
t: The density matrix to transpose.
652+
axes: The desired axis order. Only the left axes should be provided.
653+
For example, to transpose [A,B,C,A,B,C] to [C,B,A,C,B,A], `axes`
654+
should be [2,1,0].
655+
Returns:
656+
The transposed density matrix.
657+
"""
658+
axes = list(axes) + [i + len(axes) for i in axes]
659+
return transpose_state_vector_to_axis_order(t, axes)

cirq/protocols/act_on_protocol_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _act_on_fallback_(self, action, qubits, allow_decompose):
8282
return True
8383

8484
args = Args()
85-
args.qubits = tuple(cirq.LineQubit.range(3))
85+
args._qubits = tuple(cirq.LineQubit.range(3))
8686
with cirq.testing.assert_deprecated(
8787
"ActOnArgs.axes", "Use `protocols.act_on` instead.", deadline="v0.13"
8888
):

cirq/protocols/json_test_data/spec.py

+2
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,15 @@
8383
'TwoQubitInteractionHeatmap',
8484
# Intermediate states with work buffers and unknown external prng guts.
8585
'ActOnArgs',
86+
'ActOnArgsContainer',
8687
'ActOnCliffordTableauArgs',
8788
'ActOnDensityMatrixArgs',
8889
'ActOnStabilizerCHFormArgs',
8990
'ActOnStateVectorArgs',
9091
'ApplyChannelArgs',
9192
'ApplyMixtureArgs',
9293
'ApplyUnitaryArgs',
94+
'OperationTarget',
9395
# Circuit optimizers are function-like. Only attributes
9496
# are ignore_failures, tolerance, and other feature flags
9597
'AlignLeft',

cirq/sim/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
ActOnArgs,
2020
)
2121

22+
from cirq.sim.act_on_args_container import (
23+
ActOnArgsContainer,
24+
)
25+
2226
from cirq.sim.act_on_density_matrix_args import (
2327
ActOnDensityMatrixArgs,
2428
)
@@ -39,6 +43,8 @@
3943
DensityMatrixTrialResult,
4044
)
4145

46+
from cirq.sim.operation_target import OperationTarget
47+
4248
from cirq.sim.mux import (
4349
CIRCUIT_LIKE,
4450
final_density_matrix,

cirq/sim/act_on_args.py

+65-5
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,39 @@
1313
# limitations under the License.
1414
"""Objects and methods for acting efficiently on a state tensor."""
1515
import abc
16-
from typing import Any, Dict, List, TypeVar, TYPE_CHECKING, Sequence, Tuple, Iterable
16+
from typing import (
17+
Any,
18+
Iterable,
19+
Dict,
20+
List,
21+
TypeVar,
22+
TYPE_CHECKING,
23+
Sequence,
24+
Tuple,
25+
cast,
26+
Optional,
27+
Iterator,
28+
)
1729

1830
import numpy as np
1931

2032
from cirq import protocols
2133
from cirq._compat import deprecated
2234
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
35+
from cirq.sim.operation_target import OperationTarget
2336

2437
TSelf = TypeVar('TSelf', bound='ActOnArgs')
2538

2639
if TYPE_CHECKING:
2740
import cirq
2841

2942

30-
class ActOnArgs:
43+
class ActOnArgs(OperationTarget[TSelf]):
3144
"""State and context for an operation acting on a state tensor."""
3245

3346
def __init__(
3447
self,
35-
prng: np.random.RandomState,
48+
prng: np.random.RandomState = None,
3649
qubits: Sequence['cirq.Qid'] = None,
3750
axes: Iterable[int] = None,
3851
log_of_measurement_results: Dict[str, Any] = None,
@@ -50,17 +63,19 @@ def __init__(
5063
being recorded into. Edit it easily by calling
5164
`ActOnStateVectorArgs.record_measurement_result`.
5265
"""
66+
if prng is None:
67+
prng = cast(np.random.RandomState, np.random)
5368
if qubits is None:
5469
qubits = ()
5570
if axes is None:
5671
axes = ()
5772
if log_of_measurement_results is None:
5873
log_of_measurement_results = {}
59-
self.qubits = tuple(qubits)
74+
self._qubits = tuple(qubits)
6075
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}
6176
self._axes = tuple(axes)
6277
self.prng = prng
63-
self.log_of_measurement_results = log_of_measurement_results
78+
self._log_of_measurement_results = log_of_measurement_results
6479

6580
def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]):
6681
"""Adds a measurement result to the log.
@@ -90,6 +105,51 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
90105
def copy(self: TSelf) -> TSelf:
91106
"""Creates a copy of the object."""
92107

108+
def create_merged_state(self: TSelf) -> TSelf:
109+
"""Creates a final merged state."""
110+
return self
111+
112+
def apply_operation(self, op: 'cirq.Operation'):
113+
"""Applies the operation to the state."""
114+
protocols.act_on(op, self)
115+
116+
def kronecker_product(self: TSelf, other: TSelf) -> TSelf:
117+
"""Joins two state spaces together."""
118+
raise NotImplementedError()
119+
120+
def factor(
121+
self: TSelf,
122+
qubits: Sequence['cirq.Qid'],
123+
*,
124+
validate=True,
125+
atol=1e-07,
126+
) -> Tuple[TSelf, TSelf]:
127+
"""Splits two state spaces after a measurement or reset."""
128+
raise NotImplementedError()
129+
130+
def transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid']) -> TSelf:
131+
"""Physically reindexes the state by the new basis."""
132+
raise NotImplementedError()
133+
134+
@property
135+
def log_of_measurement_results(self) -> Dict[str, Any]:
136+
return self._log_of_measurement_results
137+
138+
@property
139+
def qubits(self) -> Tuple['cirq.Qid', ...]:
140+
return self._qubits
141+
142+
def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf:
143+
if item not in self.qubit_map:
144+
raise IndexError(f'{item} not in {self.qubits}')
145+
return self
146+
147+
def __len__(self) -> int:
148+
return len(self.qubits)
149+
150+
def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
151+
return iter(self.qubits)
152+
93153
@abc.abstractmethod
94154
def _act_on_fallback_(self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool):
95155
"""Handles the act_on protocol fallback implementation."""

0 commit comments

Comments
 (0)