Skip to content

Rename cirq.channel to cirq.kraus #4195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@
definitely_commutes,
equal_up_to_global_phase,
has_channel,
has_kraus,
has_mixture,
has_stabilizer_effect,
has_unitary,
Expand All @@ -487,6 +488,7 @@
is_parameterized,
JsonResolver,
json_serializable_dataclass,
kraus,
measurement_key,
measurement_keys,
mixture,
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/quimb/density_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def circuit_to_density_matrix_tensors(

Args:
circuit: The circuit containing operations that support the
cirq.unitary() or cirq.channel() protocols.
cirq.unitary() or cirq.kraus() protocols.
qubits: The qubits in the circuit.

Returns:
Expand Down Expand Up @@ -161,8 +161,8 @@ def _positions(mi, qubits):
tags={f'Q{len(op.qubits)}', f'i{mi + 1}b', _qpos_tag(op.qubits)},
)
)
elif cirq.has_channel(op):
K = np.asarray(cirq.channel(op), dtype=np.complex128)
elif cirq.has_kraus(op):
K = np.asarray(cirq.kraus(op), dtype=np.complex128)
kraus_inds = [f'k{kraus_frontier}']
tensors.append(
qtn.Tensor(
Expand Down
44 changes: 22 additions & 22 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def assert_mixtures_equal(actual, expected):
def test_asymmetric_depolarizing_channel():
d = cirq.asymmetric_depolarize(0.1, 0.2, 0.3)
np.testing.assert_almost_equal(
cirq.channel(d),
cirq.kraus(d),
(np.sqrt(0.4) * np.eye(2), np.sqrt(0.1) * X, np.sqrt(0.2) * Y, np.sqrt(0.3) * Z),
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)


def test_asymmetric_depolarizing_mixture():
Expand Down Expand Up @@ -134,21 +134,21 @@ def test_asymmetric_depolarizing_channel_text_diagram():
def test_depolarizing_channel():
d = cirq.depolarize(0.3)
np.testing.assert_almost_equal(
cirq.channel(d),
cirq.kraus(d),
(
np.sqrt(0.7) * np.eye(2),
np.sqrt(0.1) * X,
np.sqrt(0.1) * Y,
np.sqrt(0.1) * Z,
),
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)


def test_depolarizing_channel_two_qubits():
d = cirq.depolarize(0.15, n_qubits=2)
np.testing.assert_almost_equal(
cirq.channel(d),
cirq.kraus(d),
(
np.sqrt(0.85) * np.eye(4),
np.sqrt(0.01) * np.kron(np.eye(2), X),
Expand All @@ -168,7 +168,7 @@ def test_depolarizing_channel_two_qubits():
np.sqrt(0.01) * np.kron(Z, Z),
),
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)

assert d.num_qubits() == 2
cirq.testing.assert_has_diagram(
Expand Down Expand Up @@ -310,15 +310,15 @@ def test_depolarizing_channel_text_diagram_two_qubits():
def test_generalized_amplitude_damping_channel():
d = cirq.generalized_amplitude_damp(0.1, 0.3)
np.testing.assert_almost_equal(
cirq.channel(d),
cirq.kraus(d),
(
np.sqrt(0.1) * np.array([[1.0, 0.0], [0.0, np.sqrt(1.0 - 0.3)]]),
np.sqrt(0.1) * np.array([[0.0, np.sqrt(0.3)], [0.0, 0.0]]),
np.sqrt(0.9) * np.array([[np.sqrt(1.0 - 0.3), 0.0], [0.0, 1.0]]),
np.sqrt(0.9) * np.array([[0.0, 0.0], [np.sqrt(0.3), 0.0]]),
),
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)
assert not cirq.has_mixture(d)


Expand Down Expand Up @@ -376,13 +376,13 @@ def test_generalized_amplitude_damping_channel_text_diagram():
def test_amplitude_damping_channel():
d = cirq.amplitude_damp(0.3)
np.testing.assert_almost_equal(
cirq.channel(d),
cirq.kraus(d),
(
np.array([[1.0, 0.0], [0.0, np.sqrt(1.0 - 0.3)]]),
np.array([[0.0, np.sqrt(0.3)], [0.0, 0.0]]),
),
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)
assert not cirq.has_mixture(d)


Expand Down Expand Up @@ -432,22 +432,22 @@ def test_amplitude_damping_channel_text_diagram():
def test_reset_channel():
r = cirq.reset(cirq.LineQubit(0))
np.testing.assert_almost_equal(
cirq.channel(r), (np.array([[1.0, 0.0], [0.0, 0]]), np.array([[0.0, 1.0], [0.0, 0.0]]))
cirq.kraus(r), (np.array([[1.0, 0.0], [0.0, 0]]), np.array([[0.0, 1.0], [0.0, 0.0]]))
)
assert cirq.has_channel(r)
assert cirq.has_kraus(r)
assert not cirq.has_mixture(r)
assert cirq.qid_shape(r) == (2,)

r = cirq.reset(cirq.LineQid(0, dimension=3))
np.testing.assert_almost_equal(
cirq.channel(r),
cirq.kraus(r),
(
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]),
np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]),
np.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]),
),
) # yapf: disable
assert cirq.has_channel(r)
assert cirq.has_kraus(r)
assert not cirq.has_mixture(r)
assert cirq.qid_shape(r) == (3,)

Expand Down Expand Up @@ -508,13 +508,13 @@ def test_reset_act_on():
def test_phase_damping_channel():
d = cirq.phase_damp(0.3)
np.testing.assert_almost_equal(
cirq.channel(d),
cirq.kraus(d),
(
np.array([[1.0, 0.0], [0.0, np.sqrt(1 - 0.3)]]),
np.array([[0.0, 0.0], [0.0, np.sqrt(0.3)]]),
),
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)
assert not cirq.has_mixture(d)


Expand Down Expand Up @@ -564,9 +564,9 @@ def test_phase_damping_channel_text_diagram():
def test_phase_flip_channel():
d = cirq.phase_flip(0.3)
np.testing.assert_almost_equal(
cirq.channel(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * Z)
cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * Z)
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)


def test_phase_flip_mixture():
Expand Down Expand Up @@ -628,9 +628,9 @@ def test_phase_flip_channel_text_diagram():
def test_bit_flip_channel():
d = cirq.bit_flip(0.3)
np.testing.assert_almost_equal(
cirq.channel(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * X)
cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * X)
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)


def test_bit_flip_mixture():
Expand Down Expand Up @@ -734,9 +734,9 @@ def test_missing_prob_mass():
def test_multi_asymmetric_depolarizing_channel():
d = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})
np.testing.assert_almost_equal(
cirq.channel(d), (np.sqrt(0.8) * np.eye(4), np.sqrt(0.2) * np.kron(X, X))
cirq.kraus(d), (np.sqrt(0.8) * np.eye(4), np.sqrt(0.2) * np.kron(X, X))
)
assert cirq.has_channel(d)
assert cirq.has_kraus(d)
np.testing.assert_equal(d._num_qubits_(), 2)

with pytest.raises(ValueError, match="num_qubits should be 1"):
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ def test_unitary():
def test_channel():
a = cirq.NamedQubit('a')
op = cirq.bit_flip(0.5).on(a)
np.testing.assert_allclose(cirq.channel(op), cirq.channel(op.gate))
assert cirq.has_channel(op)
np.testing.assert_allclose(cirq.kraus(op), cirq.kraus(op.gate))
assert cirq.has_kraus(op)

assert cirq.channel(cirq.SingleQubitGate()(a), None) is None
assert not cirq.has_channel(cirq.SingleQubitGate()(a))
assert cirq.kraus(cirq.SingleQubitGate()(a), None) is None
assert not cirq.has_kraus(cirq.SingleQubitGate()(a))


def test_measurement_key():
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,12 @@ def test_measurement_gate_diagram():

def test_measurement_channel():
np.testing.assert_allclose(
cirq.channel(cirq.MeasurementGate(1)),
cirq.kraus(cirq.MeasurementGate(1)),
(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])),
)
# yapf: disable
np.testing.assert_allclose(
cirq.channel(cirq.MeasurementGate(2)),
cirq.kraus(cirq.MeasurementGate(2)),
(np.array([[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
Expand All @@ -224,7 +224,7 @@ def test_measurement_channel():
[0, 0, 0, 0],
[0, 0, 0, 1]])))
np.testing.assert_allclose(
cirq.channel(cirq.MeasurementGate(2, qid_shape=(2, 3))),
cirq.kraus(cirq.MeasurementGate(2, qid_shape=(2, 3))),
(np.diag([1, 0, 0, 0, 0, 0]),
np.diag([0, 1, 0, 0, 0, 0]),
np.diag([0, 0, 1, 0, 0, 0]),
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/random_gate_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _has_mixture_(self):
return not self._is_parameterized_() and protocols.has_mixture(self.sub_gate)

def _has_kraus_(self):
return not self._is_parameterized_() and protocols.has_channel(self.sub_gate)
return not self._is_parameterized_() and protocols.has_kraus(self.sub_gate)

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self.probability) or protocols.is_parameterized(
Expand Down Expand Up @@ -94,7 +94,7 @@ def _kraus_(self):
if self._is_parameterized_():
return NotImplemented

channel = protocols.channel(self.sub_gate, None)
channel = protocols.kraus(self.sub_gate, None)
if channel is None:
return NotImplemented

Expand Down
18 changes: 9 additions & 9 deletions cirq-core/cirq/ops/random_gate_channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def num_qubits(self) -> int:
def test_parameterized(resolve_fn):
op = cirq.X.with_probability(sympy.Symbol('x'))
assert cirq.is_parameterized(op)
assert not cirq.has_channel(op)
assert not cirq.has_kraus(op)
assert not cirq.has_mixture(op)

op2 = resolve_fn(op, {'x': 0.5})
assert op2 == cirq.X.with_probability(0.5)
assert not cirq.is_parameterized(op2)
assert cirq.has_channel(op2)
assert cirq.has_kraus(op2)
assert cirq.has_mixture(op2)


Expand All @@ -158,7 +158,7 @@ def num_qubits(self) -> int:


def assert_channel_sums_to_identity(val):
m = cirq.channel(val)
m = cirq.kraus(val)
s = sum(np.conj(e.T) @ e for e in m)
np.testing.assert_allclose(s, np.eye(np.product(cirq.qid_shape(val))), atol=1e-8)

Expand All @@ -168,14 +168,14 @@ class NoDetailsGate(cirq.Gate):
def num_qubits(self) -> int:
return 1

assert not cirq.has_channel(NoDetailsGate().with_probability(0.5))
assert cirq.channel(NoDetailsGate().with_probability(0.5), None) is None
assert cirq.channel(cirq.X.with_probability(sympy.Symbol('x')), None) is None
assert not cirq.has_kraus(NoDetailsGate().with_probability(0.5))
assert cirq.kraus(NoDetailsGate().with_probability(0.5), None) is None
assert cirq.kraus(cirq.X.with_probability(sympy.Symbol('x')), None) is None
assert_channel_sums_to_identity(cirq.X.with_probability(0.25))
assert_channel_sums_to_identity(cirq.bit_flip(0.75).with_probability(0.25))
assert_channel_sums_to_identity(cirq.amplitude_damp(0.75).with_probability(0.25))

m = cirq.channel(cirq.X.with_probability(0.25))
m = cirq.kraus(cirq.X.with_probability(0.25))
assert len(m) == 2
np.testing.assert_allclose(
m[0],
Expand All @@ -188,7 +188,7 @@ def num_qubits(self) -> int:
atol=1e-8,
)

m = cirq.channel(cirq.bit_flip(0.75).with_probability(0.25))
m = cirq.kraus(cirq.bit_flip(0.75).with_probability(0.25))
assert len(m) == 3
np.testing.assert_allclose(
m[0],
Expand All @@ -206,7 +206,7 @@ def num_qubits(self) -> int:
atol=1e-8,
)

m = cirq.channel(cirq.amplitude_damp(0.75).with_probability(0.25))
m = cirq.kraus(cirq.amplitude_damp(0.75).with_probability(0.25))
assert len(m) == 3
np.testing.assert_allclose(
m[0],
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,10 +656,10 @@ def _mixture_(self) -> Sequence[Tuple[float, Any]]:
return protocols.mixture(self.sub_operation, NotImplemented)

def _has_kraus_(self) -> bool:
return protocols.has_channel(self.sub_operation)
return protocols.has_kraus(self.sub_operation)

def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.channel(self.sub_operation, NotImplemented)
return protocols.kraus(self.sub_operation, NotImplemented)

def _measurement_key_(self) -> str:
return protocols.measurement_key(self.sub_operation, NotImplemented)
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def test_tagged_operation_forwards_protocols():
assert cirq.decompose(tagged_h) == cirq.decompose(h)
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
assert cirq.equal_up_to_global_phase(h, tagged_h)
assert np.isclose(cirq.channel(h), cirq.channel(tagged_h)).all()
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

assert cirq.measurement_key(cirq.measure(q1, key='blah').with_tags(tag)) == 'blah'

Expand Down Expand Up @@ -631,7 +631,7 @@ def test_tagged_operation_forwards_protocols():
flip = cirq.bit_flip(0.5)(q1)
tagged_flip = cirq.bit_flip(0.5)(q1).with_tags(tag)
assert cirq.has_mixture(tagged_flip)
assert cirq.has_channel(tagged_flip)
assert cirq.has_kraus(tagged_flip)

flip_mixture = cirq.mixture(flip)
tagged_mixture = cirq.mixture(tagged_flip)
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@
)
from cirq.protocols.channel import (
channel,
kraus,
has_channel,
has_kraus,
SupportsChannel,
SupportsKraus,
)
from cirq.protocols.commutes_protocol import (
commutes,
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/protocols/apply_channel_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
apply_unitary,
ApplyUnitaryArgs,
)
from cirq.protocols.channel import channel
from cirq.protocols.channel import kraus
from cirq.protocols import qid_shape_protocol
from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -257,9 +257,9 @@ def err_str(buf_num_str):
return result

# Fallback to using the object's `_kraus_` matrices.
kraus = channel(val, None)
if kraus is not None:
return _apply_kraus(kraus, args)
ks = kraus(val, None)
if ks is not None:
return _apply_kraus(ks, args)

# Don't know how to apply channel. Fallback to specified default behavior.
if default is not RaiseTypeErrorIfNotProvided:
Expand Down
Loading