20
20
21
21
Additional providers can be registered using the `register_llm_provider` function.
22
22
"""
23
+
23
24
import asyncio
24
25
import logging
26
+ from importlib .metadata import PackageNotFoundError , version
25
27
from typing import Any , Dict , List , Optional , Type
26
28
27
29
from langchain .base_language import BaseLanguageModel
33
35
from langchain .schema .output import GenerationChunk
34
36
from langchain_community import llms
35
37
from langchain_community .llms import HuggingFacePipeline
38
+ from packaging import version as pkg_version
36
39
37
40
from nemoguardrails .rails .llm .config import Model
38
41
@@ -244,7 +247,16 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
244
247
245
248
from ._langchain_nvidia_ai_endpoints_patch import ChatNVIDIA
246
249
250
+ # Check the version
251
+ package_version = version ("langchain_nvidia_ai_endpoints" )
252
+
253
+ if _parse_version (package_version ) < (0 , 2 , 0 ):
254
+ raise ValueError (
255
+ "langchain_nvidia_ai_endpoints version must be 0.2.0 or above."
256
+ " Please upgrade it with `pip install langchain-nvidia-ai-endpoints --upgrade`."
257
+ )
247
258
return ChatNVIDIA
259
+
248
260
except ImportError :
249
261
raise ImportError (
250
262
"Could not import langchain_nvidia_ai_endpoints, please install it with "
@@ -271,3 +283,7 @@ def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
271
283
def get_llm_provider_names() -> List[str]:
    """Return the names of all registered LLM providers, sorted alphabetically.

    Returns:
        A new sorted list of provider names; callers may mutate it freely
        without affecting the registry.
    """
    # sorted() already produces a fresh list, so the extra list() wrappers
    # in the original were redundant.
    return sorted(_providers.keys())
286
+
287
+
288
+ def _parse_version (version_str ):
289
+ return tuple (map (int , (version_str .split ("." ))))
0 commit comments