Skip to content

Commit 3807d6f

Browse files
authored
DensePauliString and MutableDensePauliString docs and inconsistencies fixes (quantumlib#5624)
1 parent 1020c82 commit 3807d6f

File tree

2 files changed

+135
-66
lines changed

2 files changed

+135
-66
lines changed

Diff for: cirq/ops/dense_pauli_string.py

+114-54
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Iterator,
2525
List,
2626
Optional,
27+
overload,
2728
Sequence,
2829
Tuple,
2930
Type,
@@ -45,9 +46,8 @@
4546

4647
# Order is important! Index equals numeric value.
4748
PAULI_CHARS = 'IXYZ'
48-
PAULI_GATES: List['cirq.Gate'] = [
49-
# mypy false positive "Cannot determine type of 'I'"
50-
identity.I, # type: ignore
49+
PAULI_GATES: List[Union['cirq.Pauli', 'cirq.IdentityGate']] = [
50+
identity.I,
5151
pauli_gates.X,
5252
pauli_gates.Y,
5353
pauli_gates.Z,
@@ -58,7 +58,27 @@
5858

5959
@value.value_equality(approximate=True, distinct_child_types=True)
6060
class BaseDensePauliString(raw_types.Gate, metaclass=abc.ABCMeta):
61-
"""Parent class for `cirq.DensePauliString` and `cirq.MutableDensePauliString`."""
61+
"""Parent class for `cirq.DensePauliString` and `cirq.MutableDensePauliString`.
62+
63+
`cirq.BaseDensePauliString` is an abstract base class, which is used to implement
64+
`cirq.DensePauliString` and `cirq.MutableDensePauliString`. The non-mutable version
65+
is used as the corresponding gate for `cirq.PauliString` operation and the mutable
66+
version is mainly used for efficiently manipulating dense pauli strings.
67+
68+
See the docstrings of `cirq.DensePauliString` and `cirq.MutableDensePauliString` for more
69+
details.
70+
71+
Examples:
72+
>>> print(cirq.DensePauliString('XXIY'))
73+
+XXIY
74+
75+
>>> print(cirq.MutableDensePauliString('IZII', coefficient=-1))
76+
-IZII (mutable)
77+
78+
>>> print(cirq.DensePauliString([0, 1, 2, 3],
79+
... coefficient=sympy.Symbol('t')))
80+
t*IXYZ
81+
"""
6282

6383
I_VAL = 0
6484
X_VAL = 1
@@ -69,7 +89,7 @@ def __init__(
6989
self,
7090
pauli_mask: Union[Iterable['cirq.PAULI_GATE_LIKE'], np.ndarray],
7191
*,
72-
coefficient: Union[sympy.Expr, int, float, 'cirq.TParamValComplex'] = 1,
92+
coefficient: 'cirq.TParamValComplex' = 1,
7393
):
7494
"""Initializes a new dense pauli string.
7595
@@ -84,17 +104,6 @@ def __init__(
84104
instead of being copied.
85105
coefficient: A complex number. Usually +1, -1, 1j, or -1j but other
86106
values are supported.
87-
88-
Examples:
89-
>>> print(cirq.DensePauliString('XXIY'))
90-
+XXIY
91-
92-
>>> print(cirq.MutableDensePauliString('IZII', coefficient=-1))
93-
-IZII (mutable)
94-
95-
>>> print(cirq.DensePauliString([0, 1, 2, 3],
96-
... coefficient=sympy.Symbol('t')))
97-
t*IXYZ
98107
"""
99108
self._pauli_mask = _as_pauli_mask(pauli_mask)
100109
self._coefficient: Union[complex, sympy.Expr] = (
@@ -106,10 +115,12 @@ def __init__(
106115

107116
@property
108117
def pauli_mask(self) -> np.ndarray:
118+
"""A 1-dimensional uint8 numpy array giving a specification of Pauli gates to use."""
109119
return self._pauli_mask
110120

111121
@property
112122
def coefficient(self) -> Union[sympy.Expr, complex]:
123+
"""A complex coefficient or symbol."""
113124
return self._coefficient
114125

115126
def _json_dict_(self) -> Dict[str, Any]:
@@ -147,28 +158,31 @@ def eye(cls: Type[TCls], length: int) -> TCls:
147158
concrete_cls = cast(Callable, DensePauliString if cls is BaseDensePauliString else cls)
148159
return concrete_cls(pauli_mask=np.zeros(length, dtype=np.uint8))
149160

150-
def _num_qubits_(self):
161+
def _num_qubits_(self) -> int:
151162
return len(self)
152163

153-
def _has_unitary_(self):
164+
def _has_unitary_(self) -> bool:
154165
return not self._is_parameterized_() and (abs(abs(self.coefficient) - 1) < 1e-8)
155166

156-
def _unitary_(self):
167+
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
157168
if not self._has_unitary_():
158169
return NotImplemented
159170
return self.coefficient * linalg.kron(
160171
*[protocols.unitary(PAULI_GATES[p]) for p in self.pauli_mask]
161172
)
162173

163-
def _apply_unitary_(self, args):
174+
def _apply_unitary_(self, args) -> Union[np.ndarray, None, NotImplementedType]:
164175
if not self._has_unitary_():
165176
return NotImplemented
166177
from cirq import devices
167178

168179
qubits = devices.LineQubit.range(len(self))
169-
return protocols.apply_unitaries(self._decompose_(qubits), qubits, args)
180+
decomposed_ops = cast(Iterable['cirq.OP_TREE'], self._decompose_(qubits))
181+
return protocols.apply_unitaries(decomposed_ops, qubits, args)
170182

171-
def _decompose_(self, qubits):
183+
def _decompose_(
184+
self, qubits: Sequence['cirq.Qid']
185+
) -> Union[NotImplementedType, 'cirq.OP_TREE']:
172186
if not self._has_unitary_():
173187
return NotImplemented
174188
result = [PAULI_GATES[p].on(q) for p, q in zip(self.pauli_mask, qubits) if p]
@@ -190,18 +204,27 @@ def _resolve_parameters_(self: TCls, resolver: 'cirq.ParamResolver', recursive:
190204
def __pos__(self):
191205
return self
192206

193-
def __pow__(self, power):
207+
def __pow__(self: TCls, power: Union[int, float]) -> Union[NotImplementedType, TCls]:
208+
concrete_class = type(self)
194209
if isinstance(power, int):
195210
i_group = [1, +1j, -1, -1j]
196211
if self.coefficient in i_group:
197-
coef = i_group[i_group.index(self.coefficient) * power % 4]
212+
coef = i_group[i_group.index(cast(int, self.coefficient)) * power % 4]
198213
else:
199214
coef = self.coefficient**power
200215
if power % 2 == 0:
201-
return coef * DensePauliString.eye(len(self))
202-
return DensePauliString(coefficient=coef, pauli_mask=self.pauli_mask)
216+
return concrete_class.eye(len(self)).__mul__(coef)
217+
return concrete_class(coefficient=coef, pauli_mask=self.pauli_mask)
203218
return NotImplemented
204219

220+
@overload
221+
def __getitem__(self: TCls, item: int) -> Union['cirq.Pauli', 'cirq.IdentityGate']:
222+
pass
223+
224+
@overload
225+
def __getitem__(self: TCls, item: slice) -> TCls:
226+
pass
227+
205228
def __getitem__(self, item):
206229
if isinstance(item, int):
207230
return PAULI_GATES[self.pauli_mask[item]]
@@ -211,15 +234,15 @@ def __getitem__(self, item):
211234

212235
raise TypeError(f'indices must be integers or slices, not {type(item)}')
213236

214-
def __iter__(self) -> Iterator['cirq.Gate']:
237+
def __iter__(self) -> Iterator[Union['cirq.Pauli', 'cirq.IdentityGate']]:
215238
for i in range(len(self)):
216239
yield self[i]
217240

218-
def __len__(self):
241+
def __len__(self) -> int:
219242
return len(self.pauli_mask)
220243

221244
def __neg__(self):
222-
return DensePauliString(coefficient=-self.coefficient, pauli_mask=self.pauli_mask)
245+
return type(self)(coefficient=-self.coefficient, pauli_mask=self.pauli_mask)
223246

224247
def __truediv__(self, other):
225248
if isinstance(other, (sympy.Basic, numbers.Number)):
@@ -228,7 +251,10 @@ def __truediv__(self, other):
228251
return NotImplemented
229252

230253
def __mul__(self, other):
254+
concrete_class = type(self)
231255
if isinstance(other, BaseDensePauliString):
256+
if isinstance(other, MutableDensePauliString):
257+
concrete_class = MutableDensePauliString
232258
max_len = max(len(self.pauli_mask), len(other.pauli_mask))
233259
min_len = min(len(self.pauli_mask), len(other.pauli_mask))
234260
new_mask = np.zeros(max_len, dtype=np.uint8)
@@ -237,22 +263,22 @@ def __mul__(self, other):
237263
tweak = _vectorized_pauli_mul_phase(
238264
self.pauli_mask[:min_len], other.pauli_mask[:min_len]
239265
)
240-
return DensePauliString(
266+
return concrete_class(
241267
pauli_mask=new_mask, coefficient=self.coefficient * other.coefficient * tweak
242268
)
243269

244270
if isinstance(other, (sympy.Basic, numbers.Number)):
245271
new_coef = protocols.mul(self.coefficient, other, default=None)
246272
if new_coef is None:
247273
return NotImplemented
248-
return DensePauliString(pauli_mask=self.pauli_mask, coefficient=new_coef)
274+
return concrete_class(pauli_mask=self.pauli_mask, coefficient=new_coef)
249275

250276
split = _attempt_value_to_pauli_index(other)
251277
if split is not None:
252278
p, i = split
253279
mask = np.copy(self.pauli_mask)
254280
mask[i] ^= p
255-
return DensePauliString(
281+
return concrete_class(
256282
pauli_mask=mask,
257283
coefficient=self.coefficient * _vectorized_pauli_mul_phase(self.pauli_mask[i], p),
258284
)
@@ -268,14 +294,14 @@ def __rmul__(self, other):
268294
p, i = split
269295
mask = np.copy(self.pauli_mask)
270296
mask[i] ^= p
271-
return DensePauliString(
297+
return type(self)(
272298
pauli_mask=mask,
273299
coefficient=self.coefficient * _vectorized_pauli_mul_phase(p, self.pauli_mask[i]),
274300
)
275301

276302
return NotImplemented
277303

278-
def tensor_product(self, other: 'BaseDensePauliString') -> 'DensePauliString':
304+
def tensor_product(self: TCls, other: 'BaseDensePauliString') -> TCls:
279305
"""Concatenates dense pauli strings and multiplies their coefficients.
280306
281307
Args:
@@ -285,13 +311,13 @@ def tensor_product(self, other: 'BaseDensePauliString') -> 'DensePauliString':
285311
A dense pauli string with the concatenation of the paulis from the
286312
two input pauli strings, and the product of their coefficients.
287313
"""
288-
return DensePauliString(
314+
return type(self)(
289315
coefficient=self.coefficient * other.coefficient,
290316
pauli_mask=np.concatenate([self.pauli_mask, other.pauli_mask]),
291317
)
292318

293-
def __abs__(self):
294-
return DensePauliString(coefficient=abs(self.coefficient), pauli_mask=self.pauli_mask)
319+
def __abs__(self: TCls) -> TCls:
320+
return type(self)(coefficient=abs(self.coefficient), pauli_mask=self.pauli_mask)
295321

296322
def on(self, *qubits: 'cirq.Qid') -> 'cirq.PauliString':
297323
return self.sparse(qubits)
@@ -392,17 +418,36 @@ def copy(
392418
class DensePauliString(BaseDensePauliString):
393419
"""An immutable string of Paulis, like `XIXY`, with a coefficient.
394420
395-
This represents a Pauli operator acting on qubits.
421+
A `DensePauliString` represents a multi-qubit pauli operator, i.e. a tensor product of single
422+
qubits Pauli gates (including the `cirq.IdentityGate`), each of which would act on a
423+
different qubit. When applied on qubits, a `DensePauliString` results in `cirq.PauliString`
424+
as an operation.
425+
426+
Note that `cirq.PauliString` only stores a tensor product of non-identity `cirq.Pauli`
427+
operations whereas `cirq.DensePauliString` also supports storing the `cirq.IdentityGate`.
428+
429+
For example,
430+
431+
>>> dps = cirq.DensePauliString('XXIY')
432+
>>> print(dps) # 4 qubit pauli operator with 'X' on first 2 qubits, 'I' on 3rd and 'Y' on 4th.
433+
+XXIY
434+
>>> ps = dps.on(*cirq.LineQubit.range(4)) # When applied on qubits, we get a `cirq.PauliString`.
435+
>>> print(ps) # Note that `cirq.PauliString` only preserves non-identity operations.
436+
X(q(0))*X(q(1))*Y(q(3))
396437
397-
For example, `cirq.MutableDensePauliString("XXY")` represents a
398-
three qubit operation that acts with `X` on the first two qubits, and
399-
`Y` on the last.
438+
This can optionally take a coefficient, for example:
400439
401-
This can optionally take a coefficient, for example,
402-
`cirq.MutableDensePauliString("XX", 3)`, which represents 3 times
403-
the operator acting on X on two qubits.
440+
>>> dps = cirq.DensePauliString("XX", coefficient=3)
441+
>>> print(dps) # Represents 3 times the operator XX acting on two qubits.
442+
(3+0j)*XX
443+
>>> print(dps.on(*cirq.LineQubit.range(2))) # Coefficient is propagated to `cirq.PauliString`.
444+
(3+0j)*X(q(0))*X(q(1))
404445
405-
If the coefficient has magnitude of 1, then this is also a `cirq.Gate`.
446+
If the coefficient has magnitude of 1, the resulting operator is a unitary and thus is
447+
also a `cirq.Gate`.
448+
449+
Note that `DensePauliString` is an immutable object. If you need a mutable version of
450+
dense pauli strings, see `cirq.MutableDensePauliString`.
406451
"""
407452

408453
def frozen(self) -> 'DensePauliString':
@@ -425,19 +470,34 @@ def copy(
425470
class MutableDensePauliString(BaseDensePauliString):
426471
"""A mutable string of Paulis, like `XIXY`, with a coefficient.
427472
428-
This represents a Pauli operator acting on qubits.
473+
`cirq.MutableDensePauliString` is a mutable version of `cirq.DensePauliString`.
474+
It exists mainly to help mutate dense pauli strings efficiently, instead of always creating
475+
a copy, and then converting back to a frozen `cirq.DensePauliString` representation.
429476
430-
For example, `cirq.MutableDensePauliString("XXY")` represents a
431-
three qubit operation that acts with `X` on the first two qubits, and
432-
`Y` on the last.
477+
For example:
433478
434-
This can optionally take a coefficient, for example,
435-
`cirq.MutableDensePauliString("XX", 3)`, which represents 3 times
436-
the operator acting on X on two qubits.
479+
>>> mutable_dps = cirq.MutableDensePauliString('XXZZ')
480+
>>> mutable_dps[:2] = 'YY' # `cirq.MutableDensePauliString` supports item assignment.
481+
>>> print(mutable_dps)
482+
+YYZZ (mutable)
437483
438-
If the coefficient has magnitude of 1, then this is also a `cirq.Gate`.
484+
See docstrings of `cirq.DensePauliString` for more details on dense pauli strings.
439485
"""
440486

487+
@overload
488+
def __setitem__(
489+
self: 'MutableDensePauliString', key: int, value: 'cirq.PAULI_GATE_LIKE'
490+
) -> 'MutableDensePauliString':
491+
pass
492+
493+
@overload
494+
def __setitem__(
495+
self: 'MutableDensePauliString',
496+
key: slice,
497+
value: Union[Iterable['cirq.PAULI_GATE_LIKE'], np.ndarray, BaseDensePauliString],
498+
) -> 'MutableDensePauliString':
499+
pass
500+
441501
def __setitem__(self, key, value):
442502
if isinstance(key, int):
443503
self.pauli_mask[key] = _pauli_index(value)
@@ -557,7 +617,7 @@ def _as_pauli_mask(val: Union[Iterable['cirq.PAULI_GATE_LIKE'], np.ndarray]) ->
557617
return np.array([_pauli_index(v) for v in val], dtype=np.uint8)
558618

559619

560-
def _attempt_value_to_pauli_index(v: Any) -> Optional[Tuple[int, int]]:
620+
def _attempt_value_to_pauli_index(v: 'cirq.Operation') -> Optional[Tuple[int, int]]:
561621
if not isinstance(v, raw_types.Operation):
562622
return None
563623

Diff for: cirq/ops/dense_pauli_string_test.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,13 @@ def test_mul():
171171

172172
# Mixed types.
173173
m = cirq.MutableDensePauliString
174-
assert m('X') * m('Z') == -1j * f('Y')
175-
assert m('X') * f('Z') == -1j * f('Y')
174+
assert m('X') * m('Z') == -1j * m('Y')
175+
assert m('X') * f('Z') == -1j * m('Y')
176+
assert f('X') * m('Z') == -1j * m('Y')
176177
assert isinstance(f('') * f(''), cirq.DensePauliString)
177-
assert isinstance(m('') * m(''), cirq.DensePauliString)
178-
assert isinstance(m('') * f(''), cirq.DensePauliString)
178+
assert isinstance(m('') * m(''), cirq.MutableDensePauliString)
179+
assert isinstance(m('') * f(''), cirq.MutableDensePauliString)
180+
assert isinstance(f('') * m(''), cirq.MutableDensePauliString)
179181

180182
# Different lengths.
181183
assert f('I') * f('III') == f('III')
@@ -482,8 +484,9 @@ def test_tensor_product():
482484
f = cirq.DensePauliString
483485
m = cirq.MutableDensePauliString
484486
assert (2 * f('XX')).tensor_product(-f('XI')) == -2 * f('XXXI')
485-
assert m('XX', coefficient=2).tensor_product(-f('XI')) == -2 * f('XXXI')
486-
assert m('XX', coefficient=2).tensor_product(m('XI', coefficient=-1)) == -2 * f('XXXI')
487+
assert m('XX', coefficient=2).tensor_product(-f('XI')) == -2 * m('XXXI')
488+
assert f('XX', coefficient=2).tensor_product(-m('XI')) == -2 * f('XXXI')
489+
assert m('XX', coefficient=2).tensor_product(m('XI', coefficient=-1)) == -2 * m('XXXI')
487490

488491

489492
def test_commutes():
@@ -633,9 +636,15 @@ def test_idiv():
633636
def test_symbolic():
634637
t = sympy.Symbol('t')
635638
r = sympy.Symbol('r')
636-
p = cirq.MutableDensePauliString('XYZ', coefficient=t)
637-
assert p * r == cirq.DensePauliString('XYZ', coefficient=t * r)
638-
p *= r
639-
assert p == cirq.MutableDensePauliString('XYZ', coefficient=t * r)
640-
p /= r
641-
assert p == cirq.MutableDensePauliString('XYZ', coefficient=t)
639+
m = cirq.MutableDensePauliString('XYZ', coefficient=t)
640+
f = cirq.DensePauliString('XYZ', coefficient=t)
641+
assert f * r == cirq.DensePauliString('XYZ', coefficient=t * r)
642+
assert m * r == cirq.MutableDensePauliString('XYZ', coefficient=t * r)
643+
m *= r
644+
f *= r
645+
assert m == cirq.MutableDensePauliString('XYZ', coefficient=t * r)
646+
assert f == cirq.DensePauliString('XYZ', coefficient=t * r)
647+
m /= r
648+
f /= r
649+
assert m == cirq.MutableDensePauliString('XYZ', coefficient=t)
650+
assert f == cirq.DensePauliString('XYZ', coefficient=t)

0 commit comments

Comments
 (0)