-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathxmon_simulator.py
425 lines (338 loc) · 16 KB
/
xmon_simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wavefunction simulator specialized to Google's xmon gate set."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import math
import multiprocessing
from typing import Any, Dict, List, Tuple
import numpy as np
from cirq.sim.google import mem_manager
class XmonSimulator(object):
    """A wave function simulator for quantum circuits with the xmon gate set.

    Xmons have a natural gate set made up of

    * Single qubit phase gates, exp(i t Z)
    * Single qubit XY gates, exp(i t (cos(theta) X + sin(theta) Y))
    * Two qubit phase gates exp(i t |11><11|)

    This simulator will do sharded simulation of the wave function using
    python's multiprocessing module.

    Simulator should be used like a context manager, because the worker pool
    is created in __enter__ and shut down in __exit__:

        with XmonSimulator(num_qubits=3) as s:
            s.simulate_phases({(1,): 0.25})
            s.simulate_xy(2, 0.25, 0.25)
            ...
    """

    def __init__(self,
                 num_qubits: int,
                 num_prefix_qubits: int = None,
                 initial_state: int = 0,
                 shard_for_small_num_qubits: bool = True):
        """Construct a new Simulator.

        Args:
            num_qubits: The number of qubits to simulate.
            num_prefix_qubits: The wavefunction of the qubits is sharded into
                (2 ** num_prefix_qubits) parts. If this is None, then this
                will shard over the largest power of two not exceeding the cpu
                count. Values larger than num_qubits are capped at num_qubits.
            initial_state: The initial state to start from, expressed as an
                integer of the computational basis. Integer to bitwise indices
                is little endian. Must satisfy
                0 <= initial_state < 2 ** num_qubits.
            shard_for_small_num_qubits: Despite its name, when this is True
                (the default) and strictly fewer than 10 qubits are simulated,
                sharding is disabled (num_prefix_qubits is forced to 0). Pass
                False to keep sharding for small qubit counts, which is useful
                for testing the sharded code paths.

        Raises:
            ValueError: initial_state is not a valid computational basis index.
        """
        # Validate before any shared memory is allocated so a bad argument
        # cannot leave half-initialized shared arrays behind.
        if not 0 <= initial_state < 2 ** num_qubits:
            raise ValueError(
                'initial_state {} is out of range for {} qubits.'.format(
                    initial_state, num_qubits))
        self._num_qubits = num_qubits
        if num_prefix_qubits is None:
            # floor(log2(cpu_count)) computed exactly with integer arithmetic;
            # int(math.log(n, 2)) can round down on floating point error.
            num_prefix_qubits = multiprocessing.cpu_count().bit_length() - 1
        if num_prefix_qubits > num_qubits:
            num_prefix_qubits = num_qubits
        if shard_for_small_num_qubits and num_qubits < 10:
            num_prefix_qubits = 0
        self._num_prefix_qubits = num_prefix_qubits
        # Each shard is of a dimension equal to 2 ** num_shard_qubits.
        self._num_shard_qubits = self._num_qubits - self._num_prefix_qubits
        self._num_shards = 2 ** self._num_prefix_qubits
        self._shard_size = 2 ** self._num_shard_qubits
        # TODO(dabacon): This could be parallelized.
        self._init_shared_mem(initial_state)

    def _init_shared_mem(self, initial_state: int):
        """Allocates every shared array and records its handle."""
        self._shared_mem_dict = {}
        self.init_z_vects()
        self._init_scratch()
        self._init_state(initial_state)

    def init_z_vects(self):
        """Initializes bitwise vectors which are precomputed in shared memory.

        There are two families of vectors. Each has one row per shard qubit
        and one column per basis index within a shard:

        * zero one vectors: v[j][k] = (k >> j) & 1, i.e. 1 if bit j of k is
          set and 0 otherwise (the little-endian binary digits of k).
        * pm (plus/minus) vectors: pm[j][k] = 1 - 2 * v[j][k], i.e. -1 if
          bit j of k is set and +1 otherwise. Row j is the diagonal of Pauli
          Z acting on shard qubit j.

        Example pm vectors for three shard qubits:

            [[1, -1, 1, -1, 1, -1, 1, -1],
             [1, 1, -1, -1, 1, 1, -1, -1],
             [1, 1, 1, 1, -1, -1, -1, -1]]
        """
        shard_size = 2 ** self._num_shard_qubits
        a, b = np.indices((shard_size, self._num_shard_qubits))
        a >>= b
        a &= 1
        zero_one_vects = np.ascontiguousarray(a.transpose())
        zero_one_vects_handle = mem_manager.SharedMemManager.create_array(
            zero_one_vects)
        self._shared_mem_dict['zero_one_vects_handle'] = zero_one_vects_handle

        pm_vects = 1 - 2 * zero_one_vects
        pm_vects_handle = mem_manager.SharedMemManager.create_array(pm_vects)
        self._shared_mem_dict['pm_vects_handle'] = pm_vects_handle

    def _init_scratch(self):
        """Initializes a scratch pad equal in size to the wavefunction."""
        scratch = np.zeros((self._num_shards, self._shard_size),
                           dtype=np.complex64)
        scratch_handle = mem_manager.SharedMemManager.create_array(
            scratch.view(dtype=np.float32))
        self._shared_mem_dict['scratch_handle'] = scratch_handle

    def _init_state(self, initial_state: int):
        """Initializes the sharded wavefunction to the given basis state."""
        state = np.zeros((self._num_shards, self._shard_size),
                         dtype=np.complex64)
        shard_num = initial_state // self._shard_size
        state[shard_num][initial_state % self._shard_size] = 1.0
        state_handle = mem_manager.SharedMemManager.create_array(
            state.view(dtype=np.float32))
        self._shared_mem_dict['state_handle'] = state_handle

    def __del__(self):
        """Frees every shared array allocated by this simulator."""
        # __init__ may have raised before _shared_mem_dict was created, in
        # which case there is nothing to free.
        for handle in getattr(self, '_shared_mem_dict', {}).values():
            mem_manager.SharedMemManager.free_array(handle)

    def __enter__(self):
        """Starts a worker pool with one process per shard."""
        self._pool = multiprocessing.Pool(processes=self._num_shards)
        return self

    def __exit__(self, *args):
        """Closes the worker pool and waits for the workers to exit."""
        self._pool.close()
        self._pool.join()

    def _shard_num_args(
            self, constant_dict: Dict[str, Any] = None) -> List[Dict[str, Any]]:
        """Helper that returns a list of dicts including a num_shard entry.

        The dict for each entry also includes shared_mem_dict, the number of
        shards, the number of shard qubits, and the supplied constant dict.

        Args:
            constant_dict: Dictionary that will be updated to every element of
                the returned list of dictionaries. It is copied, not mutated.

        Returns:
            A list of dictionaries. Each dictionary is constant except for the
            'shard_num' key which ranges from 0 to number of shards - 1.
            Included keys are 'num_shards' and 'num_shard_qubits' along with
            all the keys in constant_dict and the shared memory handles.
        """
        args = []
        for shard_num in range(self._num_shards):
            append_dict = dict(constant_dict) if constant_dict else {}
            append_dict['shard_num'] = shard_num
            append_dict['num_shards'] = self._num_shards
            append_dict['num_shard_qubits'] = self._num_shard_qubits
            append_dict.update(self._shared_mem_dict)
            args.append(append_dict)
        return args

    @property
    def current_state(self):
        """Returns the current wavefunction as a flat array."""
        return np.array(self._pool.map(_state_shard,
                                       self._shard_num_args())).flatten()

    def simulate_phases(self, phase_map: Dict[Tuple[int], float]):
        """Simulate a set of phase gates on the xmon architecture.

        Args:
            phase_map: A map from a tuple of indices to a value, one for each
                phase gate being simulated. If the tuple key has one index,
                then this is a Z phase gate on the index-th qubit with a
                rotation angle of 2 pi times the value of the map. If the
                tuple key has two indices, then this is a |11> phasing gate,
                acting on the qubits at the two indices, and a rotation angle
                of 2 pi times the value of the map.

        Raises:
            ValueError: A key does not contain exactly one or two indices.
                Raised before the state is modified.
        """
        # Previously such keys were silently skipped; reject them up front.
        for indices in phase_map:
            if len(indices) not in (1, 2):
                raise ValueError(
                    'Phase gate keys must have one or two indices, '
                    'got {!r}.'.format(indices))

        self._pool.map(_clear_scratch, self._shard_num_args())
        # Iterate over the map of phase data.
        for indices, turns in phase_map.items():
            args = self._shard_num_args({'indices': indices, 'turns': turns})
            if len(indices) == 1:
                self._pool.map(_single_qubit_accumulate_into_scratch, args)
            else:
                self._pool.map(_two_qubit_accumulate_into_scratch, args)
        # Exponentiate the phases and add them into the state.
        self._pool.map(_apply_scratch_as_phase, self._shard_num_args())

    def simulate_xy(self, index: int, turns: float, rotation_axis_turns: float):
        """Simulate a single qubit XY rotation gate.

        The gate simulated is

            cos(2 pi turns) I + i sin(2 pi turns) *
                (cos(2 pi rotation_axis_turns) X
                 + sin(2 pi rotation_axis_turns) Y)

        Args:
            index: The qubit to act on.
            turns: The amount of the overall rotation, see formula above.
            rotation_axis_turns: The angle between the pauli X and Y
                operators, see the formula above.
        """
        args = self._shard_num_args({
            'index': index,
            'turns': turns,
            'rotation_axis_turns': rotation_axis_turns
        })
        if index >= self._num_shard_qubits:
            # XY gate spans shards: every shard reads its partner shard, so
            # results go to scratch first and are copied back afterwards.
            self._pool.map(_clear_scratch, args)
            self._pool.map(_xy_between_shards, args)
            self._pool.map(_copy_scratch_to_state, args)
        else:
            # XY gate is within a shard.
            self._pool.map(_xy_within_shard, args)

    def simulate_measurement(self, index: int) -> bool:
        """Simulates a single qubit measurement in the computational basis.

        Args:
            index: Which qubit is measured.

        Returns:
            True iff the measurement result corresponds to the |1> state.
        """
        args = self._shard_num_args({'index': index})
        prob_one = np.sum(self._pool.map(_one_prob_per_shard, args))
        # np.random.random() is uniform on [0, 1), so a strict comparison
        # yields P(True) == prob_one and never reports |1> when prob_one is 0
        # (which would make the collapse divide by zero).
        result = bool(np.random.random() < prob_one)

        args = self._shard_num_args({
            'index': index,
            'result': result,
            'prob_one': prob_one
        })
        self._pool.map(_collapse_state, args)
        return result
def _state_shard(args: Dict[str, Any]) -> np.ndarray:
    """Returns this worker's shard of the shared wavefunction.

    Args:
        args: Per-shard arguments as built by XmonSimulator._shard_num_args;
            the 'state_handle' and 'shard_num' entries are read.

    Returns:
        Row args['shard_num'] of the shared state array reinterpreted as
        complex64. Integer indexing of the view returns a view, so in-place
        updates by callers are meant to land in the shared buffer (assumes
        SharedMemManager.get_array returns a shared-memory-backed array).
    """
    shared_floats = mem_manager.SharedMemManager.get_array(args['state_handle'])
    all_shards = shared_floats.view(dtype=np.complex64)
    return all_shards[args['shard_num']]
def _scratch_shard(args: Dict[str, Any]) -> np.ndarray:
    """Returns this worker's shard of the shared scratch pad.

    Args:
        args: Per-shard arguments; the 'scratch_handle' and 'shard_num'
            entries are read.

    Returns:
        Row args['shard_num'] of the shared scratch array reinterpreted as
        complex64 (a view, so callers can accumulate into it in place).
    """
    shared_floats = mem_manager.SharedMemManager.get_array(
        args['scratch_handle'])
    all_shards = shared_floats.view(dtype=np.complex64)
    return all_shards[args['shard_num']]
def _pm_vects(args: Dict[str, Any]) -> np.ndarray:
    """Returns the shared plus/minus vectors built by init_z_vects.

    Row j, column k holds 1 - 2 * ((k >> j) & 1): the Pauli Z eigenvalue of
    shard qubit j for shard basis index k.
    """
    handle = args['pm_vects_handle']
    return mem_manager.SharedMemManager.get_array(handle)
def _zero_one_vects(args: Dict[str, Any]) -> np.ndarray:
    """Returns the shared zero/one vectors built by init_z_vects.

    Row j, column k holds (k >> j) & 1: whether shard qubit j is set in
    shard basis index k.
    """
    handle = args['zero_one_vects_handle']
    return mem_manager.SharedMemManager.get_array(handle)
def as_raw_array(arr: np.ndarray) -> multiprocessing.RawArray:
    """Returns a multiprocessing.RawArray for a given numpy array.

    The array is first exposed as a ctypes array; its element type and
    contents then initialize a new RawArray (a copy, not shared with arr).
    """
    ctypes_view = np.ctypeslib.as_ctypes(arr)
    element_type = ctypes_view._type_  # pylint: disable=protected-access
    return multiprocessing.RawArray(element_type, ctypes_view)
def _kth_bit(x: int, k: int) -> int:
"""Returns 1 if the kth bit of x is set, 0 otherwise."""
return (x >> k) & 1
def _clear_scratch(args: Dict[str, Any]):
    """Zeroes this worker's scratch shard in place."""
    scratch = _scratch_shard(args)
    scratch[...] = 0
def _single_qubit_accumulate_into_scratch(args: Dict[str, Any]):
    """Accumulates a single qubit Z phase into this worker's scratch shard.

    Adds turns times the Pauli Z eigenvalue (+1 or -1) of the target qubit to
    every scratch entry.

    Args:
        args: Per-shard arguments; reads 'indices' (a one element tuple),
            'shard_num', 'turns' and 'num_shard_qubits'.
    """
    qubit = args['indices'][0]
    shard_num = args['shard_num']
    turns = args['turns']
    num_shard_qubits = args['num_shard_qubits']
    scratch = _scratch_shard(args)

    if qubit < num_shard_qubits:
        # Target lies inside the shard: the sign varies per amplitude.
        scratch += turns * _pm_vects(args)[qubit]
        return

    # Target is a prefix qubit whose value is fixed by the shard number, so
    # one sign applies to the whole shard.
    prefix_bit = _kth_bit(shard_num, qubit - num_shard_qubits)
    scratch += turns * (1 - 2 * prefix_bit)
def _one_projector(args: Dict[str, Any], index: int) -> np.ndarray:
    """Returns the |1> projector of qubit `index`, restricted to this shard.

    For a shard qubit this is the matching row of the zero/one vectors (1
    where the qubit is set). For a prefix qubit the qubit's value is constant
    across the shard, so a plain int 0 or 1 is returned instead of an array,
    despite the annotation.
    """
    num_shard_qubits = args['num_shard_qubits']
    if index < num_shard_qubits:
        return _zero_one_vects(args)[index]
    return _kth_bit(args['shard_num'], index - num_shard_qubits)
def _two_qubit_accumulate_into_scratch(args: Dict[str, Any]):
    """Accumulates a two qubit |11> phase into this worker's scratch shard.

    Adds -turns to entries where both qubits are 1 and +turns everywhere
    else.

    Args:
        args: Per-shard arguments; reads 'indices' (a two element tuple) and
            'turns'.
    """
    first, second = args['indices']
    turns = args['turns']
    scratch = _scratch_shard(args)
    both_set = _one_projector(args, first) * _one_projector(args, second)
    scratch += turns * (1 - 2 * both_set)
def _apply_scratch_as_phase(args: Dict[str, Any]):
    """Multiplies the state shard in place by exp(2 pi i * scratch shard)."""
    amplitudes = _state_shard(args)
    phase_turns = _scratch_shard(args)
    amplitudes *= np.exp((2j * np.pi) * phase_turns)
def _xy_within_shard(args: Dict[str, Any]):
    """Applies an XY gate whose target qubit lies inside each shard.

    Writes cos(2 pi turns) psi + i sin(2 pi turns) *
    (cos(2 pi axis) X + sin(2 pi axis) Y) psi back into the state shard.

    Args:
        args: Per-shard arguments; reads 'index', 'turns',
            'rotation_axis_turns' and 'num_shard_qubits'.
    """
    qubit = args['index']
    turns = args['turns']
    axis_turns = args['rotation_axis_turns']
    state = _state_shard(args)
    z_signs = _pm_vects(args)[qubit]
    num_shard_qubits = args['num_shard_qubits']

    # View the shard as (high bits, target bit, low bits) and reverse the
    # middle axis. This pairs every amplitude with its partner that differs
    # only in the target bit (Pauli X). Flattening the reversed,
    # non-contiguous view produces a copy, so `state` is not yet touched.
    blocks = (2 ** (num_shard_qubits - 1 - qubit), 2, 2 ** qubit)
    flipped = state.reshape(blocks)[:, ::-1, :].reshape(2 ** num_shard_qubits)

    cos = np.cos(2 * np.pi * turns)
    sin = np.sin(2 * np.pi * turns)
    # The Y component is X times -i * (Z eigenvalue of the target row).
    axis_factor = (np.cos(2 * np.pi * axis_turns)
                   - 1j * np.sin(2 * np.pi * axis_turns) * z_signs)
    rotated = cos * state + 1j * sin * flipped * axis_factor
    np.copyto(state, rotated)
def _xy_between_shards(args: Dict[str, Any]):
    """Applies an XY gate whose target qubit is one of the prefix qubits.

    The partner amplitudes live in the shard whose number differs in the
    target bit, so the new amplitudes are written to this worker's scratch
    shard instead of the state (simulate_xy copies them back afterwards).

    Args:
        args: Per-shard arguments; reads 'shard_num', 'num_shard_qubits',
            'index', 'turns', 'rotation_axis_turns' and 'state_handle'.
    """
    shard_num = args['shard_num']
    state = _state_shard(args)
    prefix_bit = args['index'] - args['num_shard_qubits']
    turns = args['turns']
    axis_turns = args['rotation_axis_turns']

    partner_num = shard_num ^ (1 << prefix_bit)
    all_shards = mem_manager.SharedMemManager.get_array(
        args['state_handle']).view(np.complex64)
    partner = all_shards[partner_num]

    cos = np.cos(2 * np.pi * turns)
    sin = np.sin(2 * np.pi * turns)
    cos_axis = np.cos(2 * np.pi * axis_turns)
    sin_axis = np.sin(2 * np.pi * axis_turns)
    # Z eigenvalue of the target qubit is constant for the whole shard.
    z_sign = 1 - 2 * _kth_bit(shard_num, prefix_bit)

    scratch = _scratch_shard(args)
    rotated = state * cos + 1j * sin * partner * (
        cos_axis - 1j * sin_axis * z_sign)
    np.copyto(scratch, rotated)
def _copy_scratch_to_state(args: Dict[str, Any]):
    """Copies this worker's scratch shard into its state shard."""
    destination = _state_shard(args)
    np.copyto(destination, _scratch_shard(args))
def _one_prob_per_shard(args: Dict[str, Any]) -> float:
    """Returns this shard's contribution to P(qubit args['index'] is 1).

    This is the squared 2-norm of the shard's amplitudes masked by the |1>
    projector of the measured qubit; summing over all shards gives the total
    probability.
    """
    masked = _state_shard(args) * _one_projector(args, args['index'])
    length = np.linalg.norm(masked)
    return length * length
def _collapse_state(args: Dict[str, Any]):
    """Projects state shards onto the appropriate post measurement state.

    Amplitudes inconsistent with the observed result are zeroed and the rest
    are divided by sqrt(P(result)), in place.

    This function makes no assumptions about the interpretation of quantum
    theory.

    Args:
        args: The args from shard_num_args; reads 'index', 'result' and
            'prob_one'.
    """
    result = args['result']
    prob_one = args['prob_one']
    one_mask = _one_projector(args, args['index'])
    keep = one_mask if result else 1 - one_mask
    amplitudes = _state_shard(args)
    amplitudes *= keep
    amplitudes /= np.sqrt(prob_one if result else 1 - prob_one)