File tree 2 files changed +4
-4
lines changed
nemoguardrails/llm/providers
2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change 22
22
from langchain_core .messages import BaseMessage
23
23
from langchain_core .outputs import ChatResult
24
24
from langchain_core .pydantic_v1 import Field
25
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
25
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
26
26
27
27
log = logging .getLogger (__name__ )
28
28
@@ -49,7 +49,9 @@ def wrapper(
49
49
return wrapper
50
50
51
51
52
- class ChatNVIDIA (ChatNVIDIA ):
52
+ # NOTE: this needs to have the same name as the original class,
53
+ # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
54
+ class ChatNVIDIA (ChatNVIDIAOriginal ):
53
55
streaming : bool = Field (
54
56
default = False , description = "Whether to use streaming or not"
55
57
)
Original file line number Diff line number Diff line change @@ -243,8 +243,6 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
243
243
)
244
244
elif model_config .engine == "nvidia_ai_endpoints" or model_config .engine == "nim" :
245
245
try :
246
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
247
-
248
246
from ._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA
249
247
250
248
# Check the version
You can’t perform that action at this time.
0 commit comments