Skip to content
This repository was archived by the owner on Feb 21, 2023. It is now read-only.

Commit 57a21bb

Browse files
zinter (redis/redis-py#1520) * Made some adjustments to typing in _zaggregate Signed-off-by: Andrew-Chen-Wang <[email protected]>
1 parent dde66f3 commit 57a21bb

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

aioredis/client.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,8 @@ class Redis:
672672
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
673673
),
674674
**string_keys_to_dict(
675-
"ZPOPMAX ZPOPMIN ZDIFF ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE",
675+
"ZPOPMAX ZPOPMIN ZINTER ZDIFF ZRANGE ZRANGEBYSCORE ZREVRANGE "
676+
"ZREVRANGEBYSCORE",
676677
zset_score_pairs,
677678
),
678679
**string_keys_to_dict(
@@ -3260,16 +3261,39 @@ def zincrby(self, name: KeyT, amount: float, value: EncodableT) -> Awaitable:
32603261
"""Increment the score of ``value`` in sorted set ``name`` by ``amount``"""
32613262
return self.execute_command("ZINCRBY", name, amount, value)
32623263

3264+
def zinter(
3265+
self,
3266+
keys: KeysT,
3267+
aggregate: Optional[str] = None,
3268+
withscores: bool = False
3269+
) -> Awaitable:
3270+
"""
3271+
Return the intersect of multiple sorted sets specified by ``keys``.
3272+
With the ``aggregate`` option, it is possible to specify how the
3273+
results of the union are aggregated. This option defaults to SUM,
3274+
where the score of an element is summed across the inputs where it
3275+
exists. When this option is set to either MIN or MAX, the resulting
3276+
set will contain the minimum or maximum score of an element across
3277+
the inputs where it exists.
3278+
"""
3279+
return self._zaggregate(
3280+
"ZINTER", None, keys, aggregate, withscores=withscores
3281+
)
3282+
32633283
def zinterstore(
32643284
self,
32653285
dest: KeyT,
32663286
keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]],
32673287
aggregate: Optional[str] = None,
32683288
) -> Awaitable:
32693289
"""
3270-
Intersect multiple sorted sets specified by ``keys`` into
3271-
a new sorted set, ``dest``. Scores in the destination will be
3272-
aggregated based on the ``aggregate``, or SUM if none is provided.
3290+
Intersect multiple sorted sets specified by ``keys`` into a new
3291+
sorted set, ``dest``. Scores in the destination will be aggregated
3292+
based on the ``aggregate``. This option defaults to SUM, where the
3293+
score of an element is summed across the inputs where it exists.
3294+
When this option is set to either MIN or MAX, the resulting set will
3295+
contain the minimum or maximum score of an element across the inputs
3296+
where it exists.
32733297
"""
32743298
return self._zaggregate("ZINTERSTORE", dest, keys, aggregate)
32753299

@@ -3593,11 +3617,15 @@ def zunionstore(
35933617
def _zaggregate(
35943618
self,
35953619
command: str,
3596-
dest: KeyT,
3620+
dest: Optional[KeyT],
35973621
keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]],
35983622
aggregate: Optional[str] = None,
3623+
**options,
35993624
) -> Awaitable:
3600-
pieces: List[EncodableT] = [command, dest, len(keys)]
3625+
pieces: List[EncodableT] = [command]
3626+
if dest is not None:
3627+
pieces.append(dest)
3628+
pieces.append(len(keys))
36013629
if isinstance(keys, dict):
36023630
keys, weights = keys.keys(), keys.values()
36033631
else:
@@ -3607,8 +3635,13 @@ def _zaggregate(
36073635
pieces.append(b"WEIGHTS")
36083636
pieces.extend(weights)
36093637
if aggregate:
3610-
pieces.append(b"AGGREGATE")
3611-
pieces.append(aggregate)
3638+
if aggregate.upper() in ["SUM", "MIN", "MAX"]:
3639+
pieces.append(b'AGGREGATE')
3640+
pieces.append(aggregate)
3641+
else:
3642+
raise DataError("aggregate can be sum, min, or max")
3643+
if options.get("withscores", False):
3644+
pieces.append(b'WITHSCORES')
36123645
return self.execute_command(*pieces)
36133646

36143647
# HYPERLOGLOG COMMANDS

tests/test_commands.py

+22
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,28 @@ async def test_zlexcount(self, r: aioredis.Redis):
15961596
assert await r.zlexcount("a", "-", "+") == 7
15971597
assert await r.zlexcount("a", "[b", "[f") == 5
15981598

1599+
@skip_if_server_version_lt('6.2.0')
1600+
async def test_zinter(self, r: aioredis.Redis):
1601+
await r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1})
1602+
await r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2})
1603+
await r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4})
1604+
assert await r.zinter(['a', 'b', 'c']) == [b'a3', b'a1']
1605+
# invalid aggregation
1606+
with pytest.raises(exceptions.DataError):
1607+
await r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True)
1608+
# aggregate with SUM
1609+
assert await r.zinter(['a', 'b', 'c'], withscores=True) \
1610+
== [(b'a3', 8), (b'a1', 9)]
1611+
# aggregate with MAX
1612+
assert await r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \
1613+
== [(b'a3', 5), (b'a1', 6)]
1614+
# aggregate with MIN
1615+
assert await r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \
1616+
== [(b'a1', 1), (b'a3', 1)]
1617+
# with weights
1618+
assert await r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \
1619+
== [(b'a3', 20), (b'a1', 23)]
1620+
15991621
async def test_zinterstore_sum(self, r: aioredis.Redis):
16001622
await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1})
16011623
await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2})

0 commit comments

Comments
 (0)