|
3 | 3 | from sentry_sdk._types import TYPE_CHECKING
|
4 | 4 |
|
5 | 5 | if TYPE_CHECKING:
|
6 |
| - from typing import Iterator, Any, TypeVar, Callable |
7 |
| - |
8 |
| - F = TypeVar("F", bound=Callable[..., Any]) |
| 6 | + from typing import Iterator, Any, Iterable, List, Optional, Callable |
| 7 | + from sentry_sdk.tracing import Span |
9 | 8 |
|
10 | 9 | from sentry_sdk._functools import wraps
|
11 | 10 | from sentry_sdk.hub import Hub
|
12 | 11 | from sentry_sdk.integrations import DidNotEnable, Integration
|
13 | 12 | from sentry_sdk.utils import logger, capture_internal_exceptions
|
14 | 13 |
|
15 | 14 | try:
|
16 |
| - from openai.types.chat import ChatCompletionChunk |
17 |
| - from openai.resources.chat.completions import Completions |
18 |
| - from openai.resources import Embeddings |
| 15 | + from openai.types.chat import ChatCompletionChunk # type: ignore |
| 16 | + from openai.resources.chat.completions import Completions # type: ignore |
| 17 | + from openai.resources import Embeddings # type: ignore |
| 18 | + |
| 19 | + if TYPE_CHECKING: |
| 20 | + from openai.types.chat import ChatCompletionMessageParam |
19 | 21 | except ImportError:
|
20 | 22 | raise DidNotEnable("OpenAI not installed")
|
21 | 23 |
|
22 | 24 | try:
|
23 |
| - import tiktoken |
| 25 | + import tiktoken # type: ignore |
24 | 26 |
|
25 | 27 | enc = tiktoken.get_encoding("cl100k_base")
|
26 | 28 |
|
@@ -51,14 +53,15 @@ class OpenAIIntegration(Integration):
|
51 | 53 |
|
52 | 54 | @staticmethod
|
53 | 55 | def setup_once():
|
54 |
| - # TODO minimum version |
| 56 | + # type: () -> None |
55 | 57 | Completions.create = _wrap_chat_completion_create(Completions.create)
|
56 | 58 | Embeddings.create = _wrap_enbeddings_create(Embeddings.create)
|
57 | 59 |
|
58 | 60 |
|
59 | 61 | def _calculate_chat_completion_usage(
|
60 | 62 | messages, response, span, streaming_message_responses=None
|
61 | 63 | ):
|
| 64 | + # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]]) -> None |
62 | 65 | completion_tokens = 0
|
63 | 66 | prompt_tokens = 0
|
64 | 67 | total_tokens = 0
|
@@ -104,7 +107,7 @@ def _calculate_chat_completion_usage(
|
104 | 107 |
|
105 | 108 |
|
106 | 109 | def _wrap_chat_completion_create(f):
|
107 |
| - # type: (F) -> F |
| 110 | + # type: (Callable[..., Any]) -> Callable[..., Any] |
108 | 111 | @wraps(f)
|
109 | 112 | def new_chat_completion(*args, **kwargs):
|
110 | 113 | # type: (*Any, **Any) -> Any
|
@@ -180,10 +183,11 @@ def new_iterator() -> Iterator[ChatCompletionChunk]:
|
180 | 183 |
|
181 | 184 |
|
182 | 185 | def _wrap_enbeddings_create(f):
|
183 |
| - # type: (F) -> F |
| 186 | + # type: (Callable[..., Any]) -> Callable[..., Any] |
184 | 187 |
|
185 | 188 | @wraps(f)
|
186 | 189 | def new_embeddings_create(*args, **kwargs):
|
| 190 | + # type: (*Any, **Any) -> Any |
187 | 191 | hub = Hub.current
|
188 | 192 | integration = hub.get_integration(OpenAIIntegration)
|
189 | 193 | if integration is None:
|
|
0 commit comments