Skip to content

Commit 8bb50af

Browse files
authored
Merge pull request #367 from NVIDIA/fix/qa-fixes-6
Fix LangChain warnings and bug affecting Llama-2 example.
2 parents 0be7a72 + fedd0ed commit 8bb50af

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

nemoguardrails/llm/helpers.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ def _modify_instance_kwargs(self):
6666
"""
6767

6868
if hasattr(llm_instance, "model_kwargs"):
69-
llm_instance.model_kwargs["temperature"] = self.temperature
70-
llm_instance.model_kwargs["streaming"] = self.streaming
69+
if isinstance(llm_instance.model_kwargs, dict):
70+
llm_instance.model_kwargs["temperature"] = self.temperature
71+
llm_instance.model_kwargs["streaming"] = self.streaming
7172

7273
def _call(
7374
self,

nemoguardrails/llm/providers/providers.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _call(
6969
)
7070

7171
# Streaming for NeMo Guardrails is not supported in sync calls.
72-
if self.model_kwargs.get("streaming"):
72+
if self.model_kwargs and self.model_kwargs.get("streaming"):
7373
raise Exception(
7474
"Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
7575
)
@@ -100,7 +100,7 @@ async def _acall(
100100
)
101101

102102
# Handle streaming, if the flag is set
103-
if self.model_kwargs.get("streaming"):
103+
if self.model_kwargs and self.model_kwargs.get("streaming"):
104104
# Retrieve the streamer object, needs to be set in model_kwargs
105105
streamer = self.model_kwargs.get("streamer")
106106
if not streamer:
@@ -153,7 +153,18 @@ async def _acall(self, *args, **kwargs):
153153

154154
def discover_langchain_providers():
155155
"""Automatically discover all LLM providers from LangChain."""
156-
_providers.update(llms.type_to_cls_dict)
156+
# To deal with deprecated stuff and avoid warnings, we compose the type_to_cls_dict here
157+
if hasattr(llms, "get_type_to_cls_dict"):
158+
type_to_cls_dict = {
159+
k: v()
160+
for k, v in llms.get_type_to_cls_dict().items()
161+
# Exclude deprecated ones
162+
if k not in ["mlflow-chat", "databricks-chat"]
163+
}
164+
else:
165+
type_to_cls_dict = llms.type_to_cls_dict
166+
167+
_providers.update(type_to_cls_dict)
157168

158169
# We make sure we have OpenAI from the right package.
159170
if "openai" in _providers:

0 commit comments

Comments (0)