Skip to content

Commit 6fcaa0a

Browse files
authored
Add type hints to Redis (#3110)
1 parent cc62d1f commit 6fcaa0a

File tree

2 files changed

+79
-26
lines changed
  • instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis

2 files changed

+79
-26
lines changed

Diff for: instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

+79-26
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ def response_hook(span, instance, response):
9191
---
9292
"""
9393

94-
import typing
95-
from typing import Any, Collection
94+
from __future__ import annotations
95+
96+
from typing import TYPE_CHECKING, Any, Callable, Collection
9697

9798
import redis
9899
from wrapt import wrap_function_wrapper
@@ -109,18 +110,43 @@ def response_hook(span, instance, response):
109110
from opentelemetry.instrumentation.redis.version import __version__
110111
from opentelemetry.instrumentation.utils import unwrap
111112
from opentelemetry.semconv.trace import SpanAttributes
112-
from opentelemetry.trace import Span, StatusCode
113+
from opentelemetry.trace import Span, StatusCode, Tracer
113114

114-
_DEFAULT_SERVICE = "redis"
115+
if TYPE_CHECKING:
116+
from typing import Awaitable, TypeVar
115117

116-
_RequestHookT = typing.Optional[
117-
typing.Callable[
118-
[Span, redis.connection.Connection, typing.List, typing.Dict], None
118+
import redis.asyncio.client
119+
import redis.asyncio.cluster
120+
import redis.client
121+
import redis.cluster
122+
import redis.connection
123+
124+
_RequestHookT = Callable[
125+
[Span, redis.connection.Connection, list[Any], dict[str, Any]], None
119126
]
120-
]
121-
_ResponseHookT = typing.Optional[
122-
typing.Callable[[Span, redis.connection.Connection, Any], None]
123-
]
127+
_ResponseHookT = Callable[[Span, redis.connection.Connection, Any], None]
128+
129+
AsyncPipelineInstance = TypeVar(
130+
"AsyncPipelineInstance",
131+
redis.asyncio.client.Pipeline,
132+
redis.asyncio.cluster.ClusterPipeline,
133+
)
134+
AsyncRedisInstance = TypeVar(
135+
"AsyncRedisInstance", redis.asyncio.Redis, redis.asyncio.RedisCluster
136+
)
137+
PipelineInstance = TypeVar(
138+
"PipelineInstance",
139+
redis.client.Pipeline,
140+
redis.cluster.ClusterPipeline,
141+
)
142+
RedisInstance = TypeVar(
143+
"RedisInstance", redis.client.Redis, redis.cluster.RedisCluster
144+
)
145+
R = TypeVar("R")
146+
147+
148+
_DEFAULT_SERVICE = "redis"
149+
124150

125151
_REDIS_ASYNCIO_VERSION = (4, 2, 0)
126152
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
@@ -132,7 +158,9 @@ def response_hook(span, instance, response):
132158
_FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"]
133159

134160

135-
def _set_connection_attributes(span, conn):
161+
def _set_connection_attributes(
162+
span: Span, conn: RedisInstance | AsyncRedisInstance
163+
) -> None:
136164
if not span.is_recording() or not hasattr(conn, "connection_pool"):
137165
return
138166
for key, value in _extract_conn_attributes(
@@ -141,7 +169,9 @@ def _set_connection_attributes(span, conn):
141169
span.set_attribute(key, value)
142170

143171

144-
def _build_span_name(instance, cmd_args):
172+
def _build_span_name(
173+
instance: RedisInstance | AsyncRedisInstance, cmd_args: tuple[Any, ...]
174+
) -> str:
145175
if len(cmd_args) > 0 and cmd_args[0]:
146176
if cmd_args[0] == "FT.SEARCH":
147177
name = "redis.search"
@@ -154,7 +184,9 @@ def _build_span_name(instance, cmd_args):
154184
return name
155185

156186

157-
def _build_span_meta_data_for_pipeline(instance):
187+
def _build_span_meta_data_for_pipeline(
188+
instance: PipelineInstance | AsyncPipelineInstance,
189+
) -> tuple[list[Any], str, str]:
158190
try:
159191
command_stack = (
160192
instance.command_stack
@@ -184,11 +216,16 @@ def _build_span_meta_data_for_pipeline(instance):
184216

185217
# pylint: disable=R0915
186218
def _instrument(
187-
tracer,
188-
request_hook: _RequestHookT = None,
189-
response_hook: _ResponseHookT = None,
219+
tracer: Tracer,
220+
request_hook: _RequestHookT | None = None,
221+
response_hook: _ResponseHookT | None = None,
190222
):
191-
def _traced_execute_command(func, instance, args, kwargs):
223+
def _traced_execute_command(
224+
func: Callable[..., R],
225+
instance: RedisInstance,
226+
args: tuple[Any, ...],
227+
kwargs: dict[str, Any],
228+
) -> R:
192229
query = _format_command_args(args)
193230
name = _build_span_name(instance, args)
194231
with tracer.start_as_current_span(
@@ -210,7 +247,12 @@ def _traced_execute_command(func, instance, args, kwargs):
210247
response_hook(span, instance, response)
211248
return response
212249

213-
def _traced_execute_pipeline(func, instance, args, kwargs):
250+
def _traced_execute_pipeline(
251+
func: Callable[..., R],
252+
instance: PipelineInstance,
253+
args: tuple[Any, ...],
254+
kwargs: dict[str, Any],
255+
) -> R:
214256
(
215257
command_stack,
216258
resource,
@@ -242,7 +284,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
242284

243285
return response
244286

245-
def _add_create_attributes(span, args):
287+
def _add_create_attributes(span: Span, args: tuple[Any, ...]):
246288
_set_span_attribute_if_value(
247289
span, "redis.create_index.index", _value_or_none(args, 1)
248290
)
@@ -266,7 +308,7 @@ def _add_create_attributes(span, args):
266308
field_attribute,
267309
)
268310

269-
def _add_search_attributes(span, response, args):
311+
def _add_search_attributes(span: Span, response, args):
270312
_set_span_attribute_if_value(
271313
span, "redis.search.index", _value_or_none(args, 1)
272314
)
@@ -326,7 +368,12 @@ def _add_search_attributes(span, response, args):
326368
_traced_execute_pipeline,
327369
)
328370

329-
async def _async_traced_execute_command(func, instance, args, kwargs):
371+
async def _async_traced_execute_command(
372+
func: Callable[..., Awaitable[R]],
373+
instance: AsyncRedisInstance,
374+
args: tuple[Any, ...],
375+
kwargs: dict[str, Any],
376+
) -> Awaitable[R]:
330377
query = _format_command_args(args)
331378
name = _build_span_name(instance, args)
332379

@@ -344,7 +391,12 @@ async def _async_traced_execute_command(func, instance, args, kwargs):
344391
response_hook(span, instance, response)
345392
return response
346393

347-
async def _async_traced_execute_pipeline(func, instance, args, kwargs):
394+
async def _async_traced_execute_pipeline(
395+
func: Callable[..., Awaitable[R]],
396+
instance: AsyncPipelineInstance,
397+
args: tuple[Any, ...],
398+
kwargs: dict[str, Any],
399+
) -> Awaitable[R]:
348400
(
349401
command_stack,
350402
resource,
@@ -408,14 +460,15 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs):
408460

409461

410462
class RedisInstrumentor(BaseInstrumentor):
411-
"""An instrumentor for Redis
463+
"""An instrumentor for Redis.
464+
412465
See `BaseInstrumentor`
413466
"""
414467

415468
def instrumentation_dependencies(self) -> Collection[str]:
416469
return _instruments
417470

418-
def _instrument(self, **kwargs):
471+
def _instrument(self, **kwargs: Any):
419472
"""Instruments the redis module
420473
421474
Args:
@@ -436,7 +489,7 @@ def _instrument(self, **kwargs):
436489
response_hook=kwargs.get("response_hook"),
437490
)
438491

439-
def _uninstrument(self, **kwargs):
492+
def _uninstrument(self, **kwargs: Any):
440493
if redis.VERSION < (3, 0, 0):
441494
unwrap(redis.StrictRedis, "execute_command")
442495
unwrap(redis.StrictRedis, "pipeline")

Diff for: instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)