Skip to content

Commit 174c1bb

Browse files
committed
added request and response hooks for grpc client
1 parent 4a859e3 commit 174c1bb

File tree

6 files changed

+380
-15
lines changed

6 files changed

+380
-15
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Fix exception in Urllib3 when dealing with filelike body.
1414
([#1399](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1399))
1515

16+
- Add request and response hooks for GRPC instrumentation (client only)
17+
([#14](https://github.com/helios/opentelemetry-python-contrib/pull/14))
18+
1619
### Added
1720

1821
- Add connection attributes to sqlalchemy connect span

instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py

+56-9
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ def __init__(self, filter_=None):
434434
else:
435435
filter_ = any_of(filter_, excluded_service_filter)
436436
self._filter = filter_
437+
self._request_hook = None
438+
self._response_hook = None
437439

438440
# Figures out which channel type we need to wrap
439441
def _which_channel(self, kwargs):
@@ -455,6 +457,8 @@ def instrumentation_dependencies(self) -> Collection[str]:
455457
return _instruments
456458

457459
def _instrument(self, **kwargs):
460+
self._request_hook = kwargs.get("request_hook")
461+
self._response_hook = kwargs.get("response_hook")
458462
for ctype in self._which_channel(kwargs):
459463
_wrap(
460464
"grpc",
@@ -469,11 +473,15 @@ def _uninstrument(self, **kwargs):
469473
def wrapper_fn(self, original_func, instance, args, kwargs):
470474
channel = original_func(*args, **kwargs)
471475
tracer_provider = kwargs.get("tracer_provider")
476+
request_hook = self._request_hook
477+
response_hook = self._response_hook
472478
return intercept_channel(
473479
channel,
474480
client_interceptor(
475481
tracer_provider=tracer_provider,
476482
filter_=self._filter,
483+
request_hook=request_hook,
484+
response_hook=response_hook,
477485
),
478486
)
479487

@@ -499,6 +507,8 @@ def __init__(self, filter_=None):
499507
else:
500508
filter_ = any_of(filter_, excluded_service_filter)
501509
self._filter = filter_
510+
self._request_hook = None
511+
self._response_hook = None
502512

503513
def instrumentation_dependencies(self) -> Collection[str]:
504514
return _instruments
@@ -507,20 +517,28 @@ def _add_interceptors(self, tracer_provider, kwargs):
507517
if "interceptors" in kwargs and kwargs["interceptors"]:
508518
kwargs["interceptors"] = (
509519
aio_client_interceptors(
510-
tracer_provider=tracer_provider, filter_=self._filter
520+
tracer_provider=tracer_provider,
521+
filter_=self._filter,
522+
request_hook=self._request_hook,
523+
response_hook=self._response_hook,
511524
)
512525
+ kwargs["interceptors"]
513526
)
514527
else:
515528
kwargs["interceptors"] = aio_client_interceptors(
516-
tracer_provider=tracer_provider, filter_=self._filter
529+
tracer_provider=tracer_provider,
530+
filter_=self._filter,
531+
request_hook=self._request_hook,
532+
response_hook=self._response_hook,
517533
)
518534

519535
return kwargs
520536

521537
def _instrument(self, **kwargs):
522538
self._original_insecure = grpc.aio.insecure_channel
523539
self._original_secure = grpc.aio.secure_channel
540+
self._request_hook = kwargs.get("request_hook")
541+
self._response_hook = kwargs.get("response_hook")
524542
tracer_provider = kwargs.get("tracer_provider")
525543

526544
def insecure(*args, **kwargs):
@@ -541,7 +559,9 @@ def _uninstrument(self, **kwargs):
541559
grpc.aio.secure_channel = self._original_secure
542560

543561

544-
def client_interceptor(tracer_provider=None, filter_=None):
562+
def client_interceptor(
563+
tracer_provider=None, filter_=None, request_hook=None, response_hook=None
564+
):
545565
"""Create a gRPC client channel interceptor.
546566
547567
Args:
@@ -558,7 +578,12 @@ def client_interceptor(tracer_provider=None, filter_=None):
558578

559579
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
560580

561-
return _client.OpenTelemetryClientInterceptor(tracer, filter_=filter_)
581+
return _client.OpenTelemetryClientInterceptor(
582+
tracer,
583+
filter_=filter_,
584+
request_hook=request_hook,
585+
response_hook=response_hook,
586+
)
562587

563588

564589
def server_interceptor(tracer_provider=None, filter_=None):
@@ -581,7 +606,9 @@ def server_interceptor(tracer_provider=None, filter_=None):
581606
return _server.OpenTelemetryServerInterceptor(tracer, filter_=filter_)
582607

583608

584-
def aio_client_interceptors(tracer_provider=None, filter_=None):
609+
def aio_client_interceptors(
610+
tracer_provider=None, filter_=None, request_hook=None, response_hook=None
611+
):
585612
"""Create a gRPC client channel interceptor.
586613
587614
Args:
@@ -595,10 +622,30 @@ def aio_client_interceptors(tracer_provider=None, filter_=None):
595622
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
596623

597624
return [
598-
_aio_client.UnaryUnaryAioClientInterceptor(tracer, filter_=filter_),
599-
_aio_client.UnaryStreamAioClientInterceptor(tracer, filter_=filter_),
600-
_aio_client.StreamUnaryAioClientInterceptor(tracer, filter_=filter_),
601-
_aio_client.StreamStreamAioClientInterceptor(tracer, filter_=filter_),
625+
_aio_client.UnaryUnaryAioClientInterceptor(
626+
tracer,
627+
filter_=filter_,
628+
request_hook=request_hook,
629+
response_hook=response_hook,
630+
),
631+
_aio_client.UnaryStreamAioClientInterceptor(
632+
tracer,
633+
filter_=filter_,
634+
request_hook=request_hook,
635+
response_hook=response_hook,
636+
),
637+
_aio_client.StreamUnaryAioClientInterceptor(
638+
tracer,
639+
filter_=filter_,
640+
request_hook=request_hook,
641+
response_hook=response_hook,
642+
),
643+
_aio_client.StreamStreamAioClientInterceptor(
644+
tracer,
645+
filter_=filter_,
646+
request_hook=request_hook,
647+
response_hook=response_hook,
648+
),
602649
]
603650

604651

instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16+
import logging
1617
from collections import OrderedDict
1718

1819
import grpc
@@ -28,8 +29,10 @@
2829
from opentelemetry.semconv.trace import SpanAttributes
2930
from opentelemetry.trace.status import Status, StatusCode
3031

32+
logger = logging.getLogger(__name__)
3133

32-
def _unary_done_callback(span, code, details):
34+
35+
def _unary_done_callback(span, code, details, response_hook):
3336
def callback(call):
3437
try:
3538
span.set_attribute(
@@ -43,6 +46,8 @@ def callback(call):
4346
description=details,
4447
)
4548
)
49+
response_hook(span, details)
50+
4651
finally:
4752
span.end()
4853

@@ -110,7 +115,11 @@ async def _wrap_unary_response(self, continuation, span):
110115
code = await call.code()
111116
details = await call.details()
112117

113-
call.add_done_callback(_unary_done_callback(span, code, details))
118+
call.add_done_callback(
119+
_unary_done_callback(
120+
span, code, details, self._call_response_hook
121+
)
122+
)
114123

115124
return call
116125
except grpc.aio.AioRpcError as exc:
@@ -120,6 +129,8 @@ async def _wrap_unary_response(self, continuation, span):
120129
async def _wrap_stream_response(self, span, call):
121130
try:
122131
async for response in call:
132+
if self._response_hook:
133+
self._call_response_hook(span, response)
123134
yield response
124135
except Exception as exc:
125136
self.add_error_details_to_span(span, exc)
@@ -151,6 +162,9 @@ async def intercept_unary_unary(
151162
) as span:
152163
new_details = self.propagate_trace_in_details(client_call_details)
153164

165+
if self._request_hook:
166+
self._call_request_hook(span, request)
167+
154168
continuation_with_args = functools.partial(
155169
continuation, new_details, request
156170
)
@@ -175,7 +189,8 @@ async def intercept_unary_stream(
175189
new_details = self.propagate_trace_in_details(client_call_details)
176190

177191
resp = await continuation(new_details, request)
178-
192+
if self._request_hook:
193+
self._call_request_hook(span, request)
179194
return self._wrap_stream_response(span, resp)
180195

181196

instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919

2020
"""Implementation of the invocation-side open-telemetry interceptor."""
2121

22+
import logging
2223
from collections import OrderedDict
23-
from typing import MutableMapping
24+
from typing import Callable, MutableMapping
2425

2526
import grpc
2627

@@ -33,6 +34,8 @@
3334
from opentelemetry.semconv.trace import SpanAttributes
3435
from opentelemetry.trace.status import Status, StatusCode
3536

37+
logger = logging.getLogger(__name__)
38+
3639

3740
class _CarrierSetter(Setter):
3841
"""We use a custom setter in order to be able to lower case
@@ -59,12 +62,27 @@ def callback(response_future):
5962
return callback
6063

6164

65+
def _safe_invoke(function: Callable, *args):
66+
function_name = "<unknown>"
67+
try:
68+
function_name = function.__name__
69+
function(*args)
70+
except Exception as ex: # pylint:disable=broad-except
71+
logger.error(
72+
"Error when invoking function '%s'", function_name, exc_info=ex
73+
)
74+
75+
6276
class OpenTelemetryClientInterceptor(
6377
grpcext.UnaryClientInterceptor, grpcext.StreamClientInterceptor
6478
):
65-
def __init__(self, tracer, filter_=None):
79+
def __init__(
80+
self, tracer, filter_=None, request_hook=None, response_hook=None
81+
):
6682
self._tracer = tracer
6783
self._filter = filter_
84+
self._request_hook = request_hook
85+
self._response_hook = response_hook
6886

6987
def _start_span(self, method, **kwargs):
7088
service, meth = method.lstrip("/").split("/", 1)
@@ -99,6 +117,8 @@ def _trace_result(self, span, rpc_info, result):
99117
if isinstance(result, tuple):
100118
response = result[0]
101119
rpc_info.response = response
120+
if self._response_hook:
121+
self._call_response_hook(span, response)
102122
span.end()
103123
return result
104124

@@ -127,7 +147,8 @@ def _intercept(self, request, metadata, client_info, invoker):
127147
timeout=client_info.timeout,
128148
request=request,
129149
)
130-
150+
if self._request_hook:
151+
self._call_request_hook(span, request)
131152
result = invoker(request, metadata)
132153
except Exception as exc:
133154
if isinstance(exc, grpc.RpcError):
@@ -148,6 +169,16 @@ def _intercept(self, request, metadata, client_info, invoker):
148169
span.end()
149170
return self._trace_result(span, rpc_info, result)
150171

172+
def _call_request_hook(self, span, request):
173+
if not callable(self._request_hook):
174+
return
175+
_safe_invoke(self._request_hook, span, request)
176+
177+
def _call_response_hook(self, span, response):
178+
if not callable(self._response_hook):
179+
return
180+
_safe_invoke(self._response_hook, span, response)
181+
151182
def intercept_unary(self, request, metadata, client_info, invoker):
152183
if self._filter is not None and not self._filter(client_info):
153184
return invoker(request, metadata)

0 commit comments

Comments
 (0)