Skip to content

Commit f5eef16

Browse files
authored
Fix #273 -- Use consistent hashing for PubSub (#274)
* Fix #273 -- Use consistent hashing for PubSub * Update pubsub.py * Add missing import * Refactor hash function into utils module and add test * Run black
1 parent 49e419d commit f5eef16

File tree

4 files changed

+43
-15
lines changed

4 files changed

+43
-15
lines changed

channels_redis/core.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import base64
3-
import binascii
43
import collections
54
import functools
65
import hashlib
@@ -18,6 +17,8 @@
1817
from channels.exceptions import ChannelFull
1918
from channels.layers import BaseChannelLayer
2019

20+
from .utils import _consistent_hash
21+
2122
logger = logging.getLogger(__name__)
2223

2324
AIOREDIS_VERSION = tuple(map(int, aioredis.__version__.split(".")))
@@ -858,15 +859,7 @@ def deserialize(self, message):
858859
### Internal functions ###
859860

860861
def consistent_hash(self, value):
861-
"""
862-
Maps the value to a node value between 0 and 4095
863-
using CRC, then down to one of the ring nodes.
864-
"""
865-
if isinstance(value, str):
866-
value = value.encode("utf8")
867-
bigval = binascii.crc32(value) & 0xFFF
868-
ring_divisor = 4096 / float(self.ring_size)
869-
return int(bigval / ring_divisor)
862+
return _consistent_hash(value, self.ring_size)
870863

871864
def make_fernet(self, key):
872865
"""

channels_redis/pubsub.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import aioredis
99
import msgpack
1010

11+
from .utils import _consistent_hash
12+
1113
logger = logging.getLogger(__name__)
1214

1315

@@ -106,11 +108,7 @@ def _get_shard(self, channel_or_group_name):
106108
"""
107109
Return the shard that is used exclusively for this channel or group.
108110
"""
109-
if len(self._shards) == 1:
110-
# Avoid the overhead of hashing and modulo when it is unnecessary.
111-
return self._shards[0]
112-
shard_index = abs(hash(channel_or_group_name)) % len(self._shards)
113-
return self._shards[shard_index]
111+
return self._shards[_consistent_hash(channel_or_group_name, len(self._shards))]
114112

115113
def _get_group_channel_name(self, group):
116114
"""

channels_redis/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import binascii
2+
3+
4+
def _consistent_hash(value, ring_size):
5+
"""
6+
Maps the value to a node value between 0 and 4095
7+
using CRC, then down to one of the ring nodes.
8+
"""
9+
if ring_size == 1:
10+
# Avoid the overhead of hashing and modulo when it is unnecessary.
11+
return 0
12+
13+
if isinstance(value, str):
14+
value = value.encode("utf8")
15+
bigval = binascii.crc32(value) & 0xFFF
16+
ring_divisor = 4096 / float(ring_size)
17+
return int(bigval / ring_divisor)

tests/test_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
3+
from channels_redis.utils import _consistent_hash
4+
5+
6+
@pytest.mark.parametrize(
7+
"value,ring_size,expected",
8+
[
9+
("key_one", 1, 0),
10+
("key_two", 1, 0),
11+
("key_one", 2, 1),
12+
("key_two", 2, 0),
13+
("key_one", 10, 6),
14+
("key_two", 10, 4),
15+
(b"key_one", 10, 6),
16+
(b"key_two", 10, 4),
17+
],
18+
)
19+
def test_consistent_hash_result(value, ring_size, expected):
20+
assert _consistent_hash(value, ring_size) == expected

0 commit comments

Comments
 (0)