# limitations under the License.

import logging
- from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type
+ from functools import wraps
+ from typing import Any, List, Optional

- import pkg_resources

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import generate_from_stream
- from langchain_core.messages import (
-     AIMessageChunk,
-     BaseMessage,
-     BaseMessageChunk,
-     ChatMessage,
-     ChatMessageChunk,
-     FunctionMessageChunk,
-     HumanMessageChunk,
-     SystemMessageChunk,
-     ToolMessageChunk,
- )
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+ from langchain_core.messages import BaseMessage
+ from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import Field
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
- from packaging import version
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal


log = logging.getLogger(__name__)


- def _convert_delta_to_message_chunk(
-     _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
- ) -> BaseMessageChunk:
-     role = _dict.get("role")
-     content = _dict.get("content") or ""
-     additional_kwargs: Dict = {}
-     if _dict.get("function_call"):
-         function_call = dict(_dict["function_call"])
-         if "name" in function_call and function_call["name"] is None:
-             function_call["name"] = ""
-         additional_kwargs["function_call"] = function_call
-     if _dict.get("tool_calls"):
-         additional_kwargs["tool_calls"] = _dict["tool_calls"]
-
-     if role == "user" or default_class == HumanMessageChunk:
-         return HumanMessageChunk(content=content)
-     elif role == "assistant" or default_class == AIMessageChunk:
-         return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
-     elif role == "system" or default_class == SystemMessageChunk:
-         return SystemMessageChunk(content=content)
-     elif role == "function" or default_class == FunctionMessageChunk:
-         return FunctionMessageChunk(content=content, name=_dict["name"])
-     elif role == "tool" or default_class == ToolMessageChunk:
-         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-     elif role or default_class == ChatMessageChunk:
-         return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
-     else:
-         return default_class(content=content)  # type: ignore[call-arg]
-
-
- class PatchedChatNVIDIAV1(ChatNVIDIA):
-     streaming: bool = Field(
-         default=False, description="Whether to use streaming or not"
-     )
-
-     def _generate(
+ def stream_decorator(func):
+     @wraps(func)
+     def wrapper(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
@@ -87,105 +43,30 @@ def _generate(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
-         inputs = self._custom_preprocess(messages)
-         payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
-         response = self._client.client.get_req(payload=payload)
-         responses, _ = self._client.client.postprocess(response)
-         self._set_callback_out(responses, run_manager)
-         message = ChatMessage(**self._custom_postprocess(responses))
-         generation = ChatGeneration(message=message)
-         return ChatResult(generations=[generation], llm_output=responses)
+         else:
+             return func(self, messages, stop, run_manager, **kwargs)

-     def _stream(
-         self,
-         messages: List[BaseMessage],
-         stop: Optional[Sequence[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> Iterator[ChatGenerationChunk]:
-         """Allows streaming to model!"""
-         inputs = self._custom_preprocess(messages)
-         payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs)
-         default_chunk_class = AIMessageChunk
-         for response in self._client.client.get_req_stream(payload=payload):
-             self._set_callback_out(response, run_manager)
-             chunk = _convert_delta_to_message_chunk(response, default_chunk_class)
-             default_chunk_class = chunk.__class__
-             cg_chunk = ChatGenerationChunk(message=chunk)
-             if run_manager:
-                 run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
-             yield cg_chunk
+     return wrapper


- class PatchedChatNVIDIAV2(ChatNVIDIA):
+ # NOTE: this needs to have the same name as the original class,
+ # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
+ class ChatNVIDIA(ChatNVIDIAOriginal):
    streaming: bool = Field(
        default=False, description="Whether to use streaming or not"
    )

+     @stream_decorator
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
-         stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
-         should_stream = stream if stream is not None else self.streaming
-         if should_stream:
-             stream_iter = self._stream(
-                 messages, stop=stop, run_manager=run_manager, **kwargs
-             )
-             return generate_from_stream(stream_iter)
-         inputs = [
-             _nv_vlm_adjust_input(message)
-             for message in [convert_message_to_dict(message) for message in messages]
-         ]
-         payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
-         response = self._client.client.get_req(payload=payload)
-         responses, _ = self._client.client.postprocess(response)
-         self._set_callback_out(responses, run_manager)
-         parsed_response = self._custom_postprocess(responses, streaming=False)
-         # for pre 0.2 compatibility w/ ChatMessage
-         # ChatMessage had a role property that was not present in AIMessage
-         parsed_response.update({"role": "assistant"})
-         generation = ChatGeneration(message=AIMessage(**parsed_response))
-         return ChatResult(generations=[generation], llm_output=responses)
-
-
- class ChatNVIDIAFactory:
-     RANGE1 = (version.parse("0.1.0"), version.parse("0.2.0"))
-     RANGE2 = (version.parse("0.2.0"), version.parse("0.3.0"))
-
-     @staticmethod
-     def get_package_version(package_name):
-         return version.parse(pkg_resources.get_distribution(package_name).version)
-
-     @staticmethod
-     def is_version_in_range(version, range):
-         return range[0] <= version < range[1]
-
-     @classmethod
-     def create(cls):
-         current_version = cls.get_package_version("langchain_nvidia_ai_endpoints")
-
-         if cls.is_version_in_range(current_version, cls.RANGE1):
-             log.debug(
-                 f"Using pathed version of ChatNVIDIA for version {current_version}"
-             )
-             return PatchedChatNVIDIAV1
-         elif cls.is_version_in_range(current_version, cls.RANGE2):
-             log.debug(
-                 f"Using pathed version of ChatNVIDIA for version {current_version}"
-             )
-             from langchain_community.adapters.openai import convert_message_to_dict
-             from langchain_nvidia_ai_endpoints.chat_models import _nv_vlm_adjust_input
-
-             return PatchedChatNVIDIAV2
-         else:
-             return ChatNVIDIA
-
-
- ChatNVIDIA = ChatNVIDIAFactory.create()
+         return super()._generate(
+             messages=messages, stop=stop, run_manager=run_manager, **kwargs
+         )


__all__ = ["ChatNVIDIA"]
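
For reference, a minimal usage sketch of the patched class follows. The import path `chat_nvidia_patch`, the model name, and the prompt are illustrative assumptions, and a valid NVIDIA API key is assumed to be configured for `langchain-nvidia-ai-endpoints`.

# Hypothetical usage sketch; the module path and model name below are assumptions.
from chat_nvidia_patch import ChatNVIDIA

llm = ChatNVIDIA(
    model="meta/llama3-8b-instruct",  # placeholder model name
    streaming=True,  # stream_decorator routes _generate through _stream
)

# With streaming=True, a blocking call such as `invoke` is still served by
# aggregating streamed chunks via `generate_from_stream`.
result = llm.invoke("Summarize streaming generation in one sentence.")
print(result.content)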