
Commit 9f2ca27

Merge pull request #697: Fix/langchain_nvidia_ai_endpoints patch
2 parents: dc509f6 + 88728ab

2 files changed: +34 additions, -139 deletions


nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 18 additions & 137 deletions
@@ -14,66 +14,22 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type
+from functools import wraps
+from typing import Any, List, Optional
 
-import pkg_resources
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import generate_from_stream
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage,
-    BaseMessageChunk,
-    ChatMessage,
-    ChatMessageChunk,
-    FunctionMessageChunk,
-    HumanMessageChunk,
-    SystemMessageChunk,
-    ToolMessageChunk,
-)
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatResult
 from langchain_core.pydantic_v1 import Field
-from langchain_nvidia_ai_endpoints import ChatNVIDIA
-from packaging import version
+from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
 
 log = logging.getLogger(__name__)
 
 
-def _convert_delta_to_message_chunk(
-    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
-) -> BaseMessageChunk:
-    role = _dict.get("role")
-    content = _dict.get("content") or ""
-    additional_kwargs: Dict = {}
-    if _dict.get("function_call"):
-        function_call = dict(_dict["function_call"])
-        if "name" in function_call and function_call["name"] is None:
-            function_call["name"] = ""
-        additional_kwargs["function_call"] = function_call
-    if _dict.get("tool_calls"):
-        additional_kwargs["tool_calls"] = _dict["tool_calls"]
-
-    if role == "user" or default_class == HumanMessageChunk:
-        return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
-    elif role == "system" or default_class == SystemMessageChunk:
-        return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
-        return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
-        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
-    else:
-        return default_class(content=content)  # type: ignore[call-arg]
-
-
-class PatchedChatNVIDIAV1(ChatNVIDIA):
-    streaming: bool = Field(
-        default=False, description="Whether to use streaming or not"
-    )
-
-    def _generate(
+def stream_decorator(func):
+    @wraps(func)
+    def wrapper(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
@@ -87,105 +43,30 @@ def _generate(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
             return generate_from_stream(stream_iter)
-        inputs = self._custom_preprocess(messages)
-        payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
-        response = self._client.client.get_req(payload=payload)
-        responses, _ = self._client.client.postprocess(response)
-        self._set_callback_out(responses, run_manager)
-        message = ChatMessage(**self._custom_postprocess(responses))
-        generation = ChatGeneration(message=message)
-        return ChatResult(generations=[generation], llm_output=responses)
+        else:
+            return func(self, messages, stop, run_manager, **kwargs)
 
-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[Sequence[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        """Allows streaming to model!"""
-        inputs = self._custom_preprocess(messages)
-        payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs)
-        default_chunk_class = AIMessageChunk
-        for response in self._client.client.get_req_stream(payload=payload):
-            self._set_callback_out(response, run_manager)
-            chunk = _convert_delta_to_message_chunk(response, default_chunk_class)
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
-            yield cg_chunk
+    return wrapper
 
 
-class PatchedChatNVIDIAV2(ChatNVIDIA):
+# NOTE: this needs to have the same name as the original class,
+# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
+class ChatNVIDIA(ChatNVIDIAOriginal):
     streaming: bool = Field(
         default=False, description="Whether to use streaming or not"
     )
 
+    @stream_decorator
     def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return generate_from_stream(stream_iter)
-        inputs = [
-            _nv_vlm_adjust_input(message)
-            for message in [convert_message_to_dict(message) for message in messages]
-        ]
-        payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
-        response = self._client.client.get_req(payload=payload)
-        responses, _ = self._client.client.postprocess(response)
-        self._set_callback_out(responses, run_manager)
-        parsed_response = self._custom_postprocess(responses, streaming=False)
-        # for pre 0.2 compatibility w/ ChatMessage
-        # ChatMessage had a role property that was not present in AIMessage
-        parsed_response.update({"role": "assistant"})
-        generation = ChatGeneration(message=AIMessage(**parsed_response))
-        return ChatResult(generations=[generation], llm_output=responses)
-
-
-class ChatNVIDIAFactory:
-    RANGE1 = (version.parse("0.1.0"), version.parse("0.2.0"))
-    RANGE2 = (version.parse("0.2.0"), version.parse("0.3.0"))
-
-    @staticmethod
-    def get_package_version(package_name):
-        return version.parse(pkg_resources.get_distribution(package_name).version)
-
-    @staticmethod
-    def is_version_in_range(version, range):
-        return range[0] <= version < range[1]
-
-    @classmethod
-    def create(cls):
-        current_version = cls.get_package_version("langchain_nvidia_ai_endpoints")
-
-        if cls.is_version_in_range(current_version, cls.RANGE1):
-            log.debug(
-                f"Using pathed version of ChatNVIDIA for version {current_version}"
-            )
-            return PatchedChatNVIDIAV1
-        elif cls.is_version_in_range(current_version, cls.RANGE2):
-            log.debug(
-                f"Using pathed version of ChatNVIDIA for version {current_version}"
-            )
-            from langchain_community.adapters.openai import convert_message_to_dict
-            from langchain_nvidia_ai_endpoints.chat_models import _nv_vlm_adjust_input
-
-            return PatchedChatNVIDIAV2
-        else:
-            return ChatNVIDIA
-
-
-ChatNVIDIA = ChatNVIDIAFactory.create()
+        return super()._generate(
+            messages=messages, stop=stop, run_manager=run_manager, **kwargs
+        )
 
 
 __all__ = ["ChatNVIDIA"]
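
For context on what the new code does: the patch drops the version-gated PatchedChatNVIDIAV1/V2 subclasses and instead wraps `_generate` with a small decorator that reroutes the call through `_stream` when streaming is enabled. The snippet below is a minimal, dependency-free sketch of that same pattern, not the patch itself: DemoChatModel and its generate/_stream methods are hypothetical stand-ins for the ChatNVIDIA subclass above, and the real decorator also forwards stop and run_manager.

from functools import wraps


def stream_decorator(func):
    # Route through the streaming path when the instance asks for it;
    # otherwise fall through to the wrapped method unchanged.
    @wraps(func)
    def wrapper(self, prompt, **kwargs):
        if self.streaming:
            return "".join(self._stream(prompt, **kwargs))
        return func(self, prompt, **kwargs)

    return wrapper


class DemoChatModel:
    # Hypothetical stand-in for the patched chat model in the diff above.
    def __init__(self, streaming: bool = False):
        self.streaming = streaming

    def _stream(self, prompt, **kwargs):
        # Pretend each word of the reply arrives as a separate chunk.
        for word in ("echo: " + prompt).split():
            yield word + " "

    @stream_decorator
    def generate(self, prompt, **kwargs):
        return "echo: " + prompt


print(DemoChatModel(streaming=False).generate("hello"))  # non-streaming path
print(DemoChatModel(streaming=True).generate("hello"))   # assembled from streamed chunks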

nemoguardrails/llm/providers/providers.py

Lines changed: 16 additions & 2 deletions
@@ -20,8 +20,10 @@
 
 Additional providers can be registered using the `register_llm_provider` function.
 """
+
 import asyncio
 import logging
+from importlib.metadata import PackageNotFoundError, version
 from typing import Any, Dict, List, Optional, Type
 
 from langchain.base_language import BaseLanguageModel
@@ -33,6 +35,7 @@
 from langchain.schema.output import GenerationChunk
 from langchain_community import llms
 from langchain_community.llms import HuggingFacePipeline
+from packaging import version as pkg_version
 
 from nemoguardrails.rails.llm.config import Model
 
@@ -240,11 +243,18 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
         )
     elif model_config.engine == "nvidia_ai_endpoints" or model_config.engine == "nim":
         try:
-            from langchain_nvidia_ai_endpoints import ChatNVIDIA
-
             from ._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA
 
+            # Check the version
+            package_version = version("langchain_nvidia_ai_endpoints")
+
+            if _parse_version(package_version) < (0, 2, 0):
+                raise ValueError(
+                    "langchain_nvidia_ai_endpoints version must be 0.2.0 or above."
+                    " Please upgrade it with `pip install langchain-nvidia-ai-endpoints --upgrade`."
+                )
             return ChatNVIDIA
+
         except ImportError:
            raise ImportError(
                "Could not import langchain_nvidia_ai_endpoints, please install it with "
@@ -271,3 +281,7 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
 def get_llm_provider_names() -> List[str]:
     """Returns the list of supported LLM providers."""
     return list(sorted(list(_providers.keys())))
+
+
+def _parse_version(version_str):
+    return tuple(map(int, (version_str.split("."))))
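
As a quick sanity check of the new guard (a sketch for illustration, not part of the commit): `_parse_version` compares plain integer tuples, so standard release strings order as expected; note that pre-release strings such as "0.3.0rc1" would not parse with this helper.

def _parse_version(version_str):
    # Same helper as in the diff above: "0.2.1" -> (0, 2, 1).
    return tuple(map(int, (version_str.split("."))))


assert _parse_version("0.2.1") >= (0, 2, 0)   # new enough, the patched ChatNVIDIA is returned
assert _parse_version("0.1.4") < (0, 2, 0)    # too old, get_llm_provider raises ValueError
assert _parse_version("0.10.0") > (0, 2, 0)   # tuple comparison handles multi-digit parts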
