Skip to content

Commit d805d82

Browse files
authored
Make InternalGate hashable if all gate args are hashable (#6294)
Review: @NoureldinYosri
1 parent 1366494 commit d805d82

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

cirq-google/cirq_google/ops/internal_gate.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
self.gate_module = gate_module
4444
self.gate_name = gate_name
4545
self._num_qubits = num_qubits
46-
self.gate_args = {arg: val for arg, val in kwargs.items()}
46+
self.gate_args = kwargs
4747

4848
def _num_qubits_(self) -> int:
4949
return self._num_qubits
@@ -72,4 +72,15 @@ def _json_dict_(self) -> Dict[str, Any]:
7272
)
7373

7474
def _value_equality_values_(self):
75-
return (self.gate_module, self.gate_name, self._num_qubits, self.gate_args)
75+
hashable = True
76+
for arg in self.gate_args.values():
77+
try:
78+
hash(arg)
79+
except TypeError:
80+
hashable = False
81+
return (
82+
self.gate_module,
83+
self.gate_name,
84+
self._num_qubits,
85+
frozenset(self.gate_args.items()) if hashable else self.gate_args,
86+
)

cirq-google/cirq_google/ops/internal_gate_test.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import cirq
1616
import cirq_google
17+
import pytest
1718

1819

1920
def test_internal_gate():
@@ -39,7 +40,30 @@ def test_internal_gate_with_no_args():
3940
g = cirq_google.InternalGate(gate_name="GateWithNoArgs", gate_module='test', num_qubits=3)
4041
assert str(g) == 'test.GateWithNoArgs()'
4142
want_repr = (
42-
"cirq_google.InternalGate(gate_name='GateWithNoArgs', " "gate_module='test', num_qubits=3)"
43+
"cirq_google.InternalGate(gate_name='GateWithNoArgs', gate_module='test', num_qubits=3)"
4344
)
4445
assert repr(g) == want_repr
4546
assert cirq.qid_shape(g) == (2, 2, 2)
47+
48+
49+
def test_internal_gate_with_hashable_args_is_hashable():
50+
hashable = cirq_google.InternalGate(
51+
gate_name="GateWithHashableArgs",
52+
gate_module='test',
53+
num_qubits=3,
54+
foo=1,
55+
bar="2",
56+
baz=(("a", 1),),
57+
)
58+
_ = hash(hashable)
59+
60+
unhashable = cirq_google.InternalGate(
61+
gate_name="GateWithHashableArgs",
62+
gate_module='test',
63+
num_qubits=3,
64+
foo=1,
65+
bar="2",
66+
baz={"a": 1},
67+
)
68+
with pytest.raises(TypeError, match="unhashable"):
69+
_ = hash(unhashable)

0 commit comments

Comments
 (0)