Skip to content

Commit 5dd05bf

Browse files
authored
Run product state functions inplace to avoid copies where possible (#6396)
* Run product state merges inplace to avoid copies * rename to merged_state * comment
1 parent e9e12ee commit 5dd05bf

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

cirq-core/cirq/sim/simulation_product_state.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,19 @@ def split_untangled_states(self) -> bool:
6363
return self._split_untangled_states
6464

6565
def create_merged_state(self) -> TSimulationState:
66+
merged_state = self.sim_states[None]
6667
if not self.split_untangled_states:
67-
return self.sim_states[None]
68-
final_args = self.sim_states[None]
69-
for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]):
70-
final_args = final_args.kronecker_product(args)
71-
return final_args.transpose_to_qubit_order(self.qubits)
68+
return merged_state
69+
extra_states = set([self.sim_states[k] for k in self.sim_states.keys() if k is not None])
70+
if not extra_states:
71+
return merged_state
72+
73+
# This comes from a member variable so we need to copy it if we're going to modify inplace
74+
# before returning. We're not running a step currently, so no need to copy buffers.
75+
merged_state = merged_state.copy(deep_copy_buffers=False)
76+
for state in extra_states:
77+
merged_state.kronecker_product(state, inplace=True)
78+
return merged_state.transpose_to_qubit_order(self.qubits, inplace=True)
7279

7380
def _act_on_fallback_(
7481
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
@@ -106,7 +113,7 @@ def _act_on_fallback_(
106113
if op_args_opt is None:
107114
op_args_opt = self.sim_states[q]
108115
elif q not in op_args_opt.qubits:
109-
op_args_opt = op_args_opt.kronecker_product(self.sim_states[q])
116+
op_args_opt.kronecker_product(self.sim_states[q], inplace=True)
110117
op_args = op_args_opt or self.sim_states[None]
111118

112119
# (Backfill the args map with the new value)
@@ -123,7 +130,7 @@ def _act_on_fallback_(
123130
):
124131
for q in qubits:
125132
if op_args.allows_factoring and len(op_args.qubits) > 1:
126-
q_args, op_args = op_args.factor((q,), validate=False)
133+
q_args, _ = op_args.factor((q,), validate=False, inplace=True)
127134
self._sim_states[q] = q_args
128135

129136
# (Backfill the args map with the new value)

0 commit comments

Comments
 (0)