Skip to content

Commit 014e712

Browse files
committed
VertexAI handle streaming requests
WIP using shared context manager Properly implement uninstrument Shared code with a contextmanager tmp
1 parent ad2fe81 commit 014e712

14 files changed

+1611
-51
lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -39,42 +39,55 @@
3939
---
4040
"""
4141

42+
from __future__ import annotations
43+
4244
from typing import Any, Collection
4345

4446
from wrapt import (
45-
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
47+
wrap_function_wrapper, # pyright: ignore[reportUnknownVariableType]
4648
)
4749

4850
from opentelemetry._events import get_event_logger
4951
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
5052
from opentelemetry.instrumentation.utils import unwrap
5153
from opentelemetry.instrumentation.vertexai.package import _instruments
52-
from opentelemetry.instrumentation.vertexai.patch import (
53-
generate_content_create,
54-
)
54+
from opentelemetry.instrumentation.vertexai.patch import MethodWrappers
5555
from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
5656
from opentelemetry.semconv.schemas import Schemas
5757
from opentelemetry.trace import get_tracer
5858

5959

60-
def _client_classes():
60+
def _methods_to_wrap(
    method_wrappers: MethodWrappers,
):
    """Yield one ``(client_class, method_name, wrapper)`` triple per method to patch.

    Both the v1 and the v1beta1 ``PredictionServiceClient`` classes are
    covered, and for each of them the ``generate_content`` and
    ``stream_generate_content`` methods are paired with the wrapper of the
    same name taken from ``method_wrappers``.
    """
    # This import is very slow, do it lazily in case instrument() is not called
    # pylint: disable=import-outside-toplevel
    from google.cloud.aiplatform_v1.services.prediction_service import client
    from google.cloud.aiplatform_v1beta1.services.prediction_service import (
        client as client_v1beta1,
    )

    prediction_clients = (
        client.PredictionServiceClient,
        client_v1beta1.PredictionServiceClient,
    )
    wrapped_attributes = ("generate_content", "stream_generate_content")

    for prediction_client in prediction_clients:
        for attribute in wrapped_attributes:
            # The name is read from the method object right before each yield,
            # and the wrapper is the same-named attribute on method_wrappers.
            yield (
                prediction_client,
                getattr(prediction_client, attribute).__name__,
                getattr(method_wrappers, attribute),
            )
7584

7685

7786
class VertexAIInstrumentor(BaseInstrumentor):
87+
def __init__(self) -> None:
    """Set up the list of ``(client_class, method_name)`` pairs to unwrap later."""
    super().__init__()
    # NOTE: BaseInstrumentor returns one shared instance per subclass, but
    # Python still runs __init__ on every ``VertexAIInstrumentor()`` call.
    # Only create the list the first time so that a later construction does
    # not wipe the record of methods wrapped by an earlier instrument() call,
    # which would leave _uninstrument() with nothing to unwrap.
    if not hasattr(self, "_methods_to_unwrap"):
        self._methods_to_unwrap: list[tuple[Any, str]] = []
90+
7891
def instrumentation_dependencies(self) -> Collection[str]:
7992
return _instruments
8093

@@ -95,15 +108,19 @@ def _instrument(self, **kwargs: Any):
95108
event_logger_provider=event_logger_provider,
96109
)
97110

98-
for client_class in _client_classes():
111+
method_wrappers = MethodWrappers(
112+
tracer, event_logger, is_content_enabled()
113+
)
114+
for client_class, method_name, wrapper in _methods_to_wrap(
115+
method_wrappers
116+
):
99117
wrap_function_wrapper(
100118
client_class,
101-
name="generate_content",
102-
wrapper=generate_content_create(
103-
tracer, event_logger, is_content_enabled()
104-
),
119+
name=method_name,
120+
wrapper=wrapper,
105121
)
122+
self._methods_to_unwrap.append((client_class, method_name))
106123

107124
def _uninstrument(self, **kwargs: Any) -> None:
    """Unwrap every method recorded in ``self._methods_to_unwrap``.

    Args:
        **kwargs: Accepted for interface compatibility; not used here.
    """
    for client_class, method_name in self._methods_to_unwrap:
        unwrap(client_class, method_name)
    # Forget the entries once they are unwrapped; otherwise each
    # instrument()/uninstrument() cycle would append the same pairs again and
    # the list would keep growing with stale duplicates.
    self._methods_to_unwrap.clear()

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/events.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,11 @@ def choice_event(
161161
https://github.com/open-telemetry/semantic-conventions/blob/v1.28.0/docs/gen-ai/gen-ai-events.md#event-gen_aichoice
162162
"""
163163
body: dict[str, AnyValue] = {
164-
"finish_reason": finish_reason,
165164
"index": index,
166165
"message": _asdict_filter_nulls(message),
167166
}
167+
if finish_reason:
168+
body["finish_reason"] = finish_reason
168169

169170
tool_calls_list = [
170171
_asdict_filter_nulls(tool_call) for tool_call in tool_calls

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

+77-27
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
from __future__ import annotations
1616

17+
from contextlib import contextmanager
1718
from typing import (
1819
TYPE_CHECKING,
1920
Any,
2021
Callable,
22+
Iterable,
2123
MutableSequence,
2224
)
2325

@@ -87,17 +89,17 @@ def _extract_params(
8789
)
8890

8991

90-
def generate_content_create(
91-
tracer: Tracer, event_logger: EventLogger, capture_content: bool
92-
):
93-
"""Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""
92+
class MethodWrappers:
93+
def __init__(
    self, tracer: Tracer, event_logger: EventLogger, capture_content: bool
) -> None:
    """Store the tracer, event logger and content-capture flag on the instance."""
    self.tracer, self.event_logger, self.capture_content = (
        tracer,
        event_logger,
        capture_content,
    )
9499

95-
def traced_method(
96-
wrapped: Callable[
97-
...,
98-
prediction_service.GenerateContentResponse
99-
| prediction_service_v1beta1.GenerateContentResponse,
100-
],
100+
@contextmanager
101+
def _with_instrumentation(
102+
self,
101103
instance: client.PredictionServiceClient
102104
| client_v1beta1.PredictionServiceClient,
103105
args: Any,
@@ -111,32 +113,80 @@ def traced_method(
111113
}
112114

113115
span_name = get_span_name(span_attributes)
114-
with tracer.start_as_current_span(
116+
117+
with self.tracer.start_as_current_span(
115118
name=span_name,
116119
kind=SpanKind.CLIENT,
117120
attributes=span_attributes,
118121
) as span:
119122
for event in request_to_events(
120-
params=params, capture_content=capture_content
123+
params=params, capture_content=self.capture_content
121124
):
122-
event_logger.emit(event)
125+
self.event_logger.emit(event)
123126

124127
# TODO: set error.type attribute
125128
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
126-
response = wrapped(*args, **kwargs)
127-
# TODO: handle streaming
128-
# if is_streaming(kwargs):
129-
# return StreamWrapper(
130-
# result, span, event_logger, capture_content
131-
# )
132-
133-
if span.is_recording():
134-
span.set_attributes(get_genai_response_attributes(response))
135-
for event in response_to_events(
136-
response=response, capture_content=capture_content
137-
):
138-
event_logger.emit(event)
139129

130+
def handle_response(
131+
response: prediction_service.GenerateContentResponse
132+
| prediction_service_v1beta1.GenerateContentResponse,
133+
) -> None:
134+
if span.is_recording():
135+
# When streaming, this is called multiple times so attributes would be
136+
# overwritten. In practice, it looks like the API only returns the interesting
137+
# attributes on the last streamed response. However, I couldn't find
138+
# documentation for this and setting attributes shouldn't be too expensive.
139+
span.set_attributes(
140+
get_genai_response_attributes(response)
141+
)
142+
143+
for event in response_to_events(
144+
response=response, capture_content=self.capture_content
145+
):
146+
self.event_logger.emit(event)
147+
148+
yield handle_response
149+
150+
def generate_content(
    self,
    wrapped: Callable[
        ...,
        prediction_service.GenerateContentResponse
        | prediction_service_v1beta1.GenerateContentResponse,
    ],
    instance: client.PredictionServiceClient
    | client_v1beta1.PredictionServiceClient,
    args: Any,
    kwargs: Any,
) -> (
    prediction_service.GenerateContentResponse
    | prediction_service_v1beta1.GenerateContentResponse
):
    """Call the wrapped ``generate_content`` inside the shared instrumentation.

    Args:
        wrapped: The original method; invoked as ``wrapped(*args, **kwargs)``.
        instance: The client the method was called on.
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.

    Returns:
        Whatever ``wrapped`` returned, unchanged.
    """
    instrumentation = self._with_instrumentation(instance, args, kwargs)
    with instrumentation as record_response:
        result = wrapped(*args, **kwargs)
        # Hand the response to the callback yielded by the context manager
        # before returning it to the caller.
        record_response(result)
        return result
141171

142-
return traced_method
172+
def stream_generate_content(
    self,
    wrapped: Callable[
        ...,
        Iterable[prediction_service.GenerateContentResponse]
        | Iterable[prediction_service_v1beta1.GenerateContentResponse],
    ],
    instance: client.PredictionServiceClient
    | client_v1beta1.PredictionServiceClient,
    args: Any,
    kwargs: Any,
) -> Iterable[
    prediction_service.GenerateContentResponse
    | prediction_service_v1beta1.GenerateContentResponse,
]:
    """Re-yield the wrapped stream, reporting every chunk to the instrumentation.

    This is a generator function: nothing runs (neither the instrumentation
    nor ``wrapped``) until the caller starts iterating, and the
    instrumentation context stays open until the generator finishes or is
    closed.

    Args:
        wrapped: The original method; invoked as ``wrapped(*args, **kwargs)``.
        instance: The client the method was called on.
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.

    Yields:
        Each item produced by ``wrapped``, unchanged and in order.
    """
    instrumentation = self._with_instrumentation(instance, args, kwargs)
    with instrumentation as record_response:
        stream = wrapped(*args, **kwargs)
        for chunk in stream:
            # Report the chunk first, then pass it through to the consumer.
            record_response(chunk)
            yield chunk

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/utils.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,9 @@ def _map_finish_reason(
330330
| content_v1beta1.Candidate.FinishReason,
331331
) -> FinishReason | str:
332332
EnumType = type(finish_reason) # pylint: disable=invalid-name
333-
if (
334-
finish_reason is EnumType.FINISH_REASON_UNSPECIFIED
335-
or finish_reason is EnumType.OTHER
336-
):
333+
if finish_reason is EnumType.FINISH_REASON_UNSPECIFIED:
334+
return ""
335+
if finish_reason is EnumType.OTHER:
337336
return "error"
338337
if finish_reason is EnumType.STOP:
339338
return "stop"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
interactions:
2+
- request:
3+
body: |-
4+
{
5+
"contents": [
6+
{
7+
"role": "user",
8+
"parts": [
9+
{
10+
"text": "Get weather details in New Delhi and San Francisco?"
11+
}
12+
]
13+
}
14+
],
15+
"tools": [
16+
{
17+
"functionDeclarations": [
18+
{
19+
"name": "get_current_weather",
20+
"description": "Get the current weather in a given location",
21+
"parameters": {
22+
"type": 6,
23+
"properties": {
24+
"location": {
25+
"type": 1,
26+
"description": "The location for which to get the weather. It can be a city name, a city name and state, or a zip code. Examples: 'San Francisco', 'San Francisco, CA', '95616', etc."
27+
}
28+
},
29+
"propertyOrdering": [
30+
"location"
31+
]
32+
}
33+
}
34+
]
35+
}
36+
]
37+
}
38+
headers:
39+
Accept:
40+
- '*/*'
41+
Accept-Encoding:
42+
- gzip, deflate
43+
Connection:
44+
- keep-alive
45+
Content-Length:
46+
- '824'
47+
Content-Type:
48+
- application/json
49+
User-Agent:
50+
- python-requests/2.32.3
51+
method: POST
52+
uri: https://us-central1-aiplatform.googleapis.com/v1/projects/fake-project/locations/us-central1/publishers/google/models/gemini-1.5-flash-002:streamGenerateContent?%24alt=json%3Benum-encoding%3Dint
53+
response:
54+
body:
55+
string: |-
56+
[
57+
{
58+
"candidates": [
59+
{
60+
"content": {
61+
"role": "model",
62+
"parts": [
63+
{
64+
"functionCall": {
65+
"name": "get_current_weather",
66+
"args": {
67+
"location": "New Delhi"
68+
}
69+
}
70+
},
71+
{
72+
"functionCall": {
73+
"name": "get_current_weather",
74+
"args": {
75+
"location": "San Francisco"
76+
}
77+
}
78+
}
79+
]
80+
},
81+
"finishReason": 1
82+
}
83+
],
84+
"usageMetadata": {
85+
"promptTokenCount": 72,
86+
"candidatesTokenCount": 16,
87+
"totalTokenCount": 88,
88+
"promptTokensDetails": [
89+
{
90+
"modality": 1,
91+
"tokenCount": 72
92+
}
93+
],
94+
"candidatesTokensDetails": [
95+
{
96+
"modality": 1,
97+
"tokenCount": 16
98+
}
99+
]
100+
},
101+
"modelVersion": "gemini-1.5-flash-002",
102+
"createTime": "2025-03-05T04:44:12.226326Z",
103+
"responseId": "nNbHZ5boDZeTmecP49qwuQU"
104+
}
105+
]
106+
headers:
107+
Content-Type:
108+
- application/json; charset=UTF-8
109+
Transfer-Encoding:
110+
- chunked
111+
Vary:
112+
- Origin
113+
- X-Origin
114+
- Referer
115+
content-length:
116+
- '985'
117+
status:
118+
code: 200
119+
message: OK
120+
version: 1

0 commit comments

Comments
 (0)