Skip to content

Commit 88728ab

Browse files
committed
Remove unnecessary import and fix name conflict warning.
1 parent d26f49e commit 88728ab

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from langchain_core.messages import BaseMessage
2323
from langchain_core.outputs import ChatResult
2424
from langchain_core.pydantic_v1 import Field
25-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
25+
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
2626

2727
log = logging.getLogger(__name__)
2828

@@ -49,7 +49,9 @@ def wrapper(
4949
return wrapper
5050

5151

52-
class ChatNVIDIA(ChatNVIDIA):
52+
# NOTE: this needs to have the same name as the original class,
53+
# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
54+
class ChatNVIDIA(ChatNVIDIAOriginal):
5355
streaming: bool = Field(
5456
default=False, description="Whether to use streaming or not"
5557
)

nemoguardrails/llm/providers/providers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,6 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
243243
)
244244
elif model_config.engine == "nvidia_ai_endpoints" or model_config.engine == "nim":
245245
try:
246-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
247-
248246
from ._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA
249247

250248
# Check the version

0 commit comments

Comments
 (0)