From c161e60120f6c03385ffca818f1b9ad5fbc0e886 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 15 Apr 2025 11:17:43 +0200 Subject: [PATCH 01/11] add loras support --- .../inference/_providers/_common.py | 65 +++++++++++++++---- .../inference/_providers/black_forest_labs.py | 6 +- .../inference/_providers/fal_ai.py | 44 +++++++++---- .../inference/_providers/hf_inference.py | 25 ++++--- .../inference/_providers/hyperbolic.py | 12 +++- .../inference/_providers/nebius.py | 6 +- .../inference/_providers/novita.py | 5 +- .../inference/_providers/openai.py | 6 +- .../inference/_providers/replicate.py | 13 ++-- .../inference/_providers/together.py | 6 +- 10 files changed, 140 insertions(+), 48 deletions(-) diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index d59f3f859c..a512dfa0be 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from functools import lru_cache from typing import Any, Dict, Optional, Union @@ -34,6 +35,13 @@ def filter_none(d: Dict[str, Any]) -> Dict[str, Any]: return {k: v for k, v in d.items() if v is not None} +@dataclass +class ProviderMappingInfo: + provider_id: str + adapter_weights_path: Optional[str] = None + hf_model_id: Optional[str] = None + + class TaskProviderHelper: """Base class for task-specific provider helpers.""" @@ -61,28 +69,30 @@ def prepare_request( api_key = self._prepare_api_key(api_key) # mapped model from HF model ID - mapped_model = self._prepare_mapped_model(model) + provider_mapping_info = self._prepare_mapped_model(model) # default HF headers + user headers (to customize in subclasses) headers = self._prepare_headers(headers, api_key) # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses) - url = self._prepare_url(api_key, mapped_model) + url = self._prepare_url(api_key, provider_mapping_info.provider_id) # prepare payload (to customize in subclasses) - payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model) + payload = self._prepare_payload_as_dict(inputs, parameters, provider_mapping_info=provider_mapping_info) if payload is not None: payload = recursive_merge(payload, extra_payload or {}) # body data (to customize in subclasses) - data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload) + data = self._prepare_payload_as_bytes(inputs, parameters, provider_mapping_info, extra_payload) # check if both payload and data are set and return if payload is not None and data is not None: raise ValueError("Both payload and data cannot be set in the same request.") if payload is None and data is None: raise ValueError("Either payload or data must be set in the request.") - return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers) + return RequestParameters( + url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers + ) def get_response( self, @@ -107,7 +117,7 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: ) return api_key - def _prepare_mapped_model(self, model: Optional[str]) -> str: + def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: """Return the mapped model ID to use for the request. 
Usually not overwritten in subclasses.""" @@ -116,7 +126,9 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str: # hardcoded mapping for local testing if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model): - return HARDCODED_MODEL_ID_MAPPING[self.provider][model] + return ProviderMappingInfo( + provider_id=HARDCODED_MODEL_ID_MAPPING[self.provider][model], + ) provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider) if provider_mapping is None: @@ -132,7 +144,15 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str: logger.warning( f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only." ) - return provider_mapping.provider_id + if provider_mapping.adapter == "lora": + adapter_weights_path = _fetch_lora_weights_path(model) + return ProviderMappingInfo( + adapter_weights_path=adapter_weights_path, + provider_id=provider_mapping.provider_id, + hf_model_id=model, + ) + + return ProviderMappingInfo(provider_id=provider_mapping.provider_id) def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: """Return the headers to use for the request. @@ -168,7 +188,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: """ return "" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: """Return the payload to use for the request, as a dict. Override this method in subclasses for customized payloads. @@ -177,7 +199,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: return None def _prepare_payload_as_bytes( - self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict] + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo, extra_payload: Optional[Dict] ) -> Optional[bytes]: """Return the body to use for the request, as bytes. 
@@ -199,8 +221,10 @@ def __init__(self, provider: str, base_url: str): def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/chat/completions" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: - return {"messages": inputs, **filter_none(parameters), "model": mapped_model} + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} class BaseTextGenerationTask(TaskProviderHelper): @@ -215,8 +239,10 @@ def __init__(self, provider: str, base_url: str): def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/completions" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: - return {"prompt": inputs, **filter_none(parameters), "model": mapped_model} + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + return {"prompt": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} @lru_cache(maxsize=None) @@ -233,6 +259,17 @@ def _fetch_inference_provider_mapping(model: str) -> Dict: return provider_mapping +@lru_cache(maxsize=None) +def _fetch_lora_weights_path(model: str) -> str: + from huggingface_hub.hf_api import HfApi + + repo_files = HfApi().list_repo_files(model) + safetensors_files = [f for f in repo_files if f.endswith(".safetensors")] + if len(safetensors_files) != 1: + raise ValueError(f"Expected exactly one safetensors file in repo {model}, got {len(safetensors_files)}.") + return safetensors_files[0] + + def recursive_merge(dict1: Dict, dict2: Dict) -> Dict: return { **dict1, diff --git a/src/huggingface_hub/inference/_providers/black_forest_labs.py b/src/huggingface_hub/inference/_providers/black_forest_labs.py index 3a1ace19df..bdbfff8f1b 100644 --- a/src/huggingface_hub/inference/_providers/black_forest_labs.py +++ b/src/huggingface_hub/inference/_providers/black_forest_labs.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional, Union from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none from huggingface_hub.utils import logging from huggingface_hub.utils._http import get_session @@ -27,7 +27,9 @@ def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v1/{mapped_model}" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py index 5bdf28a0e7..21f48f55e7 100644 --- a/src/huggingface_hub/inference/_providers/fal_ai.py +++ b/src/huggingface_hub/inference/_providers/fal_ai.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Optional, Union from urllib.parse import urlparse +from huggingface_hub.constants import ENDPOINT 
from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none from huggingface_hub.utils import get_session, hf_raise_for_status from huggingface_hub.utils.logging import get_logger @@ -34,7 +35,9 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask): def __init__(self): super().__init__("automatic-speech-recognition") - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): # If input is a URL, pass it directly audio_url = inputs @@ -61,14 +64,29 @@ class FalAITextToImageTask(FalAITask): def __init__(self): super().__init__("text-to-image") - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: - parameters = filter_none(parameters) - if "width" in parameters and "height" in parameters: - parameters["image_size"] = { - "width": parameters.pop("width"), - "height": parameters.pop("height"), + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + payload: Dict[str, Any] = { + "prompt": inputs, + **filter_none(parameters), + } + if "width" in payload and "height" in payload: + payload["image_size"] = { + "width": payload.pop("width"), + "height": payload.pop("height"), } - return {"prompt": inputs, **parameters} + if provider_mapping_info.adapter_weights_path is not None: + lora_path = ( + f"{ENDPOINT}/" + f"{provider_mapping_info.hf_model_id}/resolve/main/" + f"{provider_mapping_info.adapter_weights_path}" + ) + payload["loras"] = [{"path": lora_path, "scale": 1}] + if provider_mapping_info.provider_id == "fal-ai/lora": + payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0" + + return payload def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: url = _as_dict(response)["images"][0]["url"] @@ -79,7 +97,9 @@ class FalAITextToSpeechTask(FalAITask): def __init__(self): super().__init__("text-to-speech") - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: return {"lyrics": inputs, **filter_none(parameters)} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: @@ -104,7 +124,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/{mapped_model}?_subdomain=queue" return f"/{mapped_model}" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters)} def get_response( diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index 7bb54bf6a5..2448345092 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ 
b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -5,7 +5,7 @@ from huggingface_hub import constants from huggingface_hub.inference._common import _b64_encode, _open_as_binary -from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status @@ -23,9 +23,9 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: # special case: for HF Inference we allow not providing an API key return api_key or get_token() # type: ignore[return-value] - def _prepare_mapped_model(self, model: Optional[str]) -> str: + def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: if model is not None and model.startswith(("http://", "https://")): - return model + return ProviderMappingInfo(provider_id=model) model_id = model if model is not None else _fetch_recommended_models().get(self.task) if model_id is None: raise ValueError( @@ -33,7 +33,7 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str: " explicitly. Visit https://huggingface.co/tasks for more info." ) _check_supported_task(model_id, self.task) - return model_id + return ProviderMappingInfo(provider_id=model_id) def _prepare_url(self, api_key: str, mapped_model: str) -> str: # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment) @@ -47,7 +47,9 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str: else f"{self.base_url}/models/{mapped_model}" ) - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") if isinstance(inputs, Path): @@ -56,11 +58,13 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: class HFInferenceBinaryInputTask(HFInferenceTask): - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: return None def _prepare_payload_as_bytes( - self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict] + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo, extra_payload: Optional[Dict] ) -> Optional[bytes]: parameters = filter_none({k: v for k, v in parameters.items() if v is not None}) extra_payload = extra_payload or {} @@ -82,9 +86,12 @@ def _prepare_payload_as_bytes( class HFInferenceConversational(HFInferenceTask): def __init__(self): - super().__init__("conversational") + super().__init__("text-generation") - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id payload_model = parameters.get("model") or mapped_model if payload_model is None or payload_model.startswith(("http://", "https://")): diff --git a/src/huggingface_hub/inference/_providers/hyperbolic.py b/src/huggingface_hub/inference/_providers/hyperbolic.py index a317b192da..7bb970b9c7 
100644 --- a/src/huggingface_hub/inference/_providers/hyperbolic.py +++ b/src/huggingface_hub/inference/_providers/hyperbolic.py @@ -2,7 +2,12 @@ from typing import Any, Dict, Optional, Union from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import ( + BaseConversationalTask, + ProviderMappingInfo, + TaskProviderHelper, + filter_none, +) class HyperbolicTextToImageTask(TaskProviderHelper): @@ -12,7 +17,10 @@ def __init__(self): def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") diff --git a/src/huggingface_hub/inference/_providers/nebius.py b/src/huggingface_hub/inference/_providers/nebius.py index 12dd58250e..51a51f72fb 100644 --- a/src/huggingface_hub/inference/_providers/nebius.py +++ b/src/huggingface_hub/inference/_providers/nebius.py @@ -5,6 +5,7 @@ from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, + ProviderMappingInfo, TaskProviderHelper, filter_none, ) @@ -37,7 +38,10 @@ def __init__(self): def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "guidance_scale" in parameters: parameters.pop("guidance_scale") diff --git a/src/huggingface_hub/inference/_providers/novita.py b/src/huggingface_hub/inference/_providers/novita.py index 3622c04372..f8134b9d6c 100644 --- a/src/huggingface_hub/inference/_providers/novita.py +++ b/src/huggingface_hub/inference/_providers/novita.py @@ -4,6 +4,7 @@ from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, + ProviderMappingInfo, TaskProviderHelper, filter_none, ) @@ -49,7 +50,9 @@ def __init__(self): def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v3/hf/{mapped_model}" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters)} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: diff --git a/src/huggingface_hub/inference/_providers/openai.py b/src/huggingface_hub/inference/_providers/openai.py index 72dadbb77b..912303f26a 100644 --- a/src/huggingface_hub/inference/_providers/openai.py +++ b/src/huggingface_hub/inference/_providers/openai.py @@ -1,6 +1,6 @@ from typing import Optional -from huggingface_hub.inference._providers._common import BaseConversationalTask +from 
huggingface_hub.inference._providers._common import BaseConversationalTask, ProviderMappingInfo class OpenAIConversationalTask(BaseConversationalTask): @@ -16,7 +16,7 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: ) return api_key - def _prepare_mapped_model(self, model: Optional[str]) -> str: + def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: if model is None: raise ValueError("Please provide an OpenAI model ID, e.g. `gpt-4o` or `o1`.") - return model + return ProviderMappingInfo(provider_id=model) diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index 9aa69e0409..115ed5a779 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Optional, Union from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none from huggingface_hub.utils import get_session @@ -23,7 +23,10 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/predictions" return f"/v1/models/{mapped_model}/predictions" - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}} if ":" in mapped_model: version = mapped_model.split(":", 1)[1] @@ -47,7 +50,9 @@ class ReplicateTextToSpeechTask(ReplicateTask): def __init__(self): super().__init__("text-to-speech") - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: - payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, mapped_model) # type: ignore[assignment] + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS return payload diff --git a/src/huggingface_hub/inference/_providers/together.py b/src/huggingface_hub/inference/_providers/together.py index c2f0b422d0..3de2b36237 100644 --- a/src/huggingface_hub/inference/_providers/together.py +++ b/src/huggingface_hub/inference/_providers/together.py @@ -6,6 +6,7 @@ from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, + ProviderMappingInfo, TaskProviderHelper, filter_none, ) @@ -55,7 +56,10 @@ class TogetherTextToImageTask(TogetherTask): def __init__(self): super().__init__("text-to-image") - def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) if "num_inference_steps" in parameters: parameters["steps"] = parameters.pop("num_inference_steps") From 
ea59870e6526b706ef6ab34f116782c8f02a3684 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 15 Apr 2025 11:19:23 +0200 Subject: [PATCH 02/11] nit --- src/huggingface_hub/inference/_providers/hf_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index 2448345092..8599ab4567 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -86,7 +86,7 @@ def _prepare_payload_as_bytes( class HFInferenceConversational(HFInferenceTask): def __init__(self): - super().__init__("text-generation") + super().__init__("conversational") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo From 99e14a024ce048bc62a9464a27b578eb312e0ea9 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 15 Apr 2025 17:37:22 +0200 Subject: [PATCH 03/11] review suggestions --- .../inference/_providers/_common.py | 44 +++++++----- .../inference/_providers/hf_inference.py | 6 +- .../inference/_providers/openai.py | 4 +- tests/test_inference_client.py | 11 ++- tests/test_inference_providers.py | 69 ++++++++++++------- 5 files changed, 79 insertions(+), 55 deletions(-) diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index a512dfa0be..fbbc5896da 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import lru_cache -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Literal, Optional, Union from huggingface_hub import constants from huggingface_hub.inference._common import RequestParameters @@ -10,14 +10,24 @@ logger = logging.get_logger(__name__) +@dataclass +class ProviderMappingInfo: + hf_model_id: str + provider_id: str + task: str + status: Literal["live", "staging"] + adapter: Optional[str] = None + adapter_weights_path: Optional[str] = None + + # Dev purposes only. # If you want to try to run inference for a new model locally before it's registered on huggingface.co # for a given Inference Provider, you can add it to the following dictionary. 
-HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = { - # "HF model ID" => "Model ID on Inference Provider's side" +HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, ProviderMappingInfo]] = { + # "HF model ID" => ProviderMappingInfo object initialized with "Model ID on Inference Provider's side" # # Example: - # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", + # "Qwen/Qwen2.5-Coder-32B-Instruct": ProviderMappingInfo(provider_id="Qwen2.5-Coder-32B-Instruct") "cerebras": {}, "cohere": {}, "fal-ai": {}, @@ -35,13 +45,6 @@ def filter_none(d: Dict[str, Any]) -> Dict[str, Any]: return {k: v for k, v in d.items() if v is not None} -@dataclass -class ProviderMappingInfo: - provider_id: str - adapter_weights_path: Optional[str] = None - hf_model_id: Optional[str] = None - - class TaskProviderHelper: """Base class for task-specific provider helpers.""" @@ -69,7 +72,7 @@ def prepare_request( api_key = self._prepare_api_key(api_key) # mapped model from HF model ID - provider_mapping_info = self._prepare_mapped_model(model) + provider_mapping_info = self._prepare_mapping_info(model) # default HF headers + user headers (to customize in subclasses) headers = self._prepare_headers(headers, api_key) @@ -117,7 +120,7 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: ) return api_key - def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: + def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: """Return the mapped model ID to use for the request. Usually not overwritten in subclasses.""" @@ -125,10 +128,8 @@ def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: raise ValueError(f"Please provide an HF model ID supported by {self.provider}.") # hardcoded mapping for local testing - if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model): - return ProviderMappingInfo( - provider_id=HARDCODED_MODEL_ID_MAPPING[self.provider][model], - ) + if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model): + return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model] provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider) if provider_mapping is None: @@ -150,9 +151,16 @@ def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: adapter_weights_path=adapter_weights_path, provider_id=provider_mapping.provider_id, hf_model_id=model, + task=provider_mapping.task, + status=provider_mapping.status, ) - return ProviderMappingInfo(provider_id=provider_mapping.provider_id) + return ProviderMappingInfo( + provider_id=provider_mapping.provider_id, + hf_model_id=model, + task=provider_mapping.task, + status=provider_mapping.status, + ) def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: """Return the headers to use for the request. 
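A minimal usage sketch of the flow this mapping enables, seen from the client side (the repo name below is hypothetical and assumed to carry a `lora` adapter mapping for fal-ai on the Hub):

from huggingface_hub import InferenceClient

# Hypothetical LoRA adapter repo, assumed to be mapped to fal-ai with adapter="lora".
client = InferenceClient(provider="fal-ai")
image = client.text_to_image(
    "a watercolor fox in a forest",
    model="username/flux-watercolor-lora",
)
# Internally, _prepare_mapping_info() resolves the provider mapping and, for LoRA
# adapters, _fetch_lora_weights_path() locates the single .safetensors file in the
# repo so the provider payload can point at it.
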
diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index 8599ab4567..f3092e7a12 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -23,9 +23,9 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: # special case: for HF Inference we allow not providing an API key return api_key or get_token() # type: ignore[return-value] - def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: + def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: if model is not None and model.startswith(("http://", "https://")): - return ProviderMappingInfo(provider_id=model) + return ProviderMappingInfo(provider_id=model, hf_model_id=model, task=self.task, status="live") model_id = model if model is not None else _fetch_recommended_models().get(self.task) if model_id is None: raise ValueError( @@ -33,7 +33,7 @@ def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: " explicitly. Visit https://huggingface.co/tasks for more info." ) _check_supported_task(model_id, self.task) - return ProviderMappingInfo(provider_id=model_id) + return ProviderMappingInfo(provider_id=model_id, hf_model_id=model_id, task=self.task, status="live") def _prepare_url(self, api_key: str, mapped_model: str) -> str: # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment) diff --git a/src/huggingface_hub/inference/_providers/openai.py b/src/huggingface_hub/inference/_providers/openai.py index 912303f26a..50f9f67f21 100644 --- a/src/huggingface_hub/inference/_providers/openai.py +++ b/src/huggingface_hub/inference/_providers/openai.py @@ -16,7 +16,7 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: ) return api_key - def _prepare_mapped_model(self, model: Optional[str]) -> ProviderMappingInfo: + def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: if model is None: raise ValueError("Please provide an OpenAI model ID, e.g. 
`gpt-4o` or `o1`.") - return ProviderMappingInfo(provider_id=model) + return ProviderMappingInfo(provider_id=model, task="conversational", status="live", hf_model_id=model) diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 7ab6b388fd..62fc53cca1 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -47,10 +47,7 @@ ) from huggingface_hub.errors import HfHubHTTPError, ValidationError from huggingface_hub.inference._client import _open_as_binary -from huggingface_hub.inference._common import ( - _stream_chat_completion_response, - _stream_text_generation_response, -) +from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response from huggingface_hub.inference._providers import get_provider_helper from huggingface_hub.inference._providers.hf_inference import _build_chat_completion_url @@ -507,14 +504,14 @@ def test_fill_mask(self, client: InferenceClient): def test_hf_inference_get_recommended_model_has_recommendation(self) -> None: from huggingface_hub.inference._providers.hf_inference import HFInferenceTask - HFInferenceTask("feature-extraction")._prepare_mapped_model(None) == "facebook/bart-base" - HFInferenceTask("translation")._prepare_mapped_model(None) == "t5-small" + HFInferenceTask("feature-extraction")._prepare_mapping_info(None).provider_id == "facebook/bart-base" + HFInferenceTask("translation")._prepare_mapping_info(None).provider_id == "t5-small" def test_hf_inference_get_recommended_model_no_recommendation(self) -> None: from huggingface_hub.inference._providers.hf_inference import HFInferenceTask with pytest.raises(ValueError): - HFInferenceTask("text-generation")._prepare_mapped_model(None) + HFInferenceTask("text-generation")._prepare_mapping_info(None) @pytest.mark.parametrize("client", list_clients("image-classification"), indirect=True) def test_image_classification(self, client: InferenceClient): diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 67a3033840..aa2040d5a6 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -28,21 +28,13 @@ HFInferenceConversational, HFInferenceTask, ) -from huggingface_hub.inference._providers.hyperbolic import ( - HyperbolicTextGenerationTask, - HyperbolicTextToImageTask, -) +from huggingface_hub.inference._providers.hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask from huggingface_hub.inference._providers.nebius import NebiusTextToImageTask -from huggingface_hub.inference._providers.novita import ( - NovitaConversationalTask, - NovitaTextGenerationTask, -) +from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask from huggingface_hub.inference._providers.openai import OpenAIConversationalTask from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask -from huggingface_hub.inference._providers.together import ( - TogetherTextToImageTask, -) +from huggingface_hub.inference._providers.together import TogetherTextToImageTask from .testing_utils import assert_in_logs @@ -65,13 +57,13 @@ def test_api_key_missing(self): with pytest.raises(ValueError, match="You must provide an api_key.*"): helper._prepare_api_key(None) - def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): + def test_prepare_mapping_info(self, mocker, caplog: 
LogCaptureFixture): helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") caplog.set_level(logging.INFO) # Test missing model with pytest.raises(ValueError, match="Please provide an HF model ID.*"): - helper._prepare_mapped_model(None) + helper._prepare_mapping_info(None) # Test unsupported model mocker.patch( @@ -79,22 +71,29 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): return_value={"other-provider": "mapping"}, ) with pytest.raises(ValueError, match="Model test-model is not supported.*"): - helper._prepare_mapped_model("test-model") + helper._prepare_mapping_info("test-model") # Test task mismatch mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="other-task", provider_id="mapped-id", status="active")}, + return_value={ + "provider-name": mocker.Mock( + task="other-task", + provider_id="mapped-id", + status="live", + ) + }, ) with pytest.raises(ValueError, match="Model test-model is not supported for task.*"): - helper._prepare_mapped_model("test-model") + helper._prepare_mapping_info("test-model") # Test staging model mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="staging")}, ) - assert helper._prepare_mapped_model("test-model") == "mapped-id" + assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" + assert_in_logs( caplog, "Model test-model is in staging mode for provider provider-name. Meant for test purposes only." ) @@ -103,11 +102,31 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): caplog.clear() mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="active")}, + return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="live")}, ) - assert helper._prepare_mapped_model("test-model") == "mapped-id" + assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" + assert helper._prepare_mapping_info("test-model").hf_model_id == "test-model" + assert helper._prepare_mapping_info("test-model").task == "task-name" + assert helper._prepare_mapping_info("test-model").status == "live" assert len(caplog.records) == 0 + # Test with loras + mocker.patch( + "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", + return_value={ + "provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="live", adapter="lora") + }, + ) + mocker.patch( + "huggingface_hub.inference._providers._common._fetch_lora_weights_path", + return_value="lora-weights-path", + ) + assert helper._prepare_mapping_info("test-model").adapter_weights_path == "lora-weights-path" + assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" + assert helper._prepare_mapping_info("test-model").hf_model_id == "test-model" + assert helper._prepare_mapping_info("test-model").task == "task-name" + assert helper._prepare_mapping_info("test-model").status == "live" + def test_prepare_headers(self): helper = TaskProviderHelper(provider="provider-name", base_url="https://api.provider.com", task="task-name") headers = helper._prepare_headers({"custom": "header"}, "api_key") @@ -353,7 +372,7 @@ def 
test_prepare_payload_as_dict(self): class TestHFInferenceProvider: - def test_prepare_mapped_model(self, mocker): + def test_prepare_mapping_info(self, mocker): helper = HFInferenceTask("text-classification") mocker.patch( "huggingface_hub.inference._providers.hf_inference._check_supported_task", @@ -363,13 +382,13 @@ def test_prepare_mapped_model(self, mocker): "huggingface_hub.inference._providers.hf_inference._fetch_recommended_models", return_value={"text-classification": "username/repo_name"}, ) - assert helper._prepare_mapped_model("username/repo_name") == "username/repo_name" - assert helper._prepare_mapped_model(None) == "username/repo_name" - assert helper._prepare_mapped_model("https://any-url.com") == "https://any-url.com" + assert helper._prepare_mapping_info("username/repo_name").provider_id == "username/repo_name" + assert helper._prepare_mapping_info(None).provider_id == "username/repo_name" + assert helper._prepare_mapping_info("https://any-url.com").provider_id == "https://any-url.com" - def test_prepare_mapped_model_unknown_task(self): + def test_prepare_mapping_info_unknown_task(self): with pytest.raises(ValueError, match="Task unknown-task has no recommended model for HF Inference."): - HFInferenceTask("unknown-task")._prepare_mapped_model(None) + HFInferenceTask("unknown-task")._prepare_mapping_info(None) def test_prepare_url(self): helper = HFInferenceTask("text-classification") From 422f5c5bd9e2c1671b2a7dd61830b72141e3b682 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 15 Apr 2025 17:46:52 +0200 Subject: [PATCH 04/11] update inference provider mapping object --- src/huggingface_hub/hf_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index a994915244..493847e2d9 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -711,11 +711,13 @@ class InferenceProviderMapping: status: Literal["live", "staging"] provider_id: str task: str + adapter: Optional[str] = None def __init__(self, **kwargs): self.status = kwargs.pop("status") self.provider_id = kwargs.pop("providerId") self.task = kwargs.pop("task") + self.adapter = kwargs.pop("adapter", None) self.__dict__.update(**kwargs) From 9558ea803b6bc420ae53e45e2123c9e0b0c34ac3 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 15 Apr 2025 18:27:06 +0200 Subject: [PATCH 05/11] fix tests --- tests/test_inference_providers.py | 136 +++++++++++++++++++++++++----- 1 file changed, 116 insertions(+), 20 deletions(-) diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index aa2040d5a6..04a5d3e6d2 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -10,6 +10,7 @@ from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, + ProviderMappingInfo, TaskProviderHelper, recursive_merge, ) @@ -233,7 +234,14 @@ def test_prepare_url(self): def test_prepare_payload_as_dict(self): helper = CohereConversationalTask() payload = helper._prepare_payload_as_dict( - [{"role": "user", "content": "Hello!"}], {}, "CohereForAI/command-r7b-12-2024" + [{"role": "user", "content": "Hello!"}], + {}, + ProviderMappingInfo( + hf_model_id="CohereForAI/command-r7b-12-2024", + provider_id="CohereForAI/command-r7b-12-2024", + task="conversational", + status="live", + ), ) assert payload == { "messages": [{"role": "user", "content": "Hello!"}], @@ -275,7 +283,14 @@ def test_automatic_speech_recognition_response(self): def 
test_text_to_image_payload(self): helper = FalAITextToImageTask() payload = helper._prepare_payload_as_dict( - "a beautiful cat", {"width": 512, "height": 512}, "username/repo_name" + "a beautiful cat", + {"width": 512, "height": 512}, + ProviderMappingInfo( + hf_model_id="username/repo_name", + provider_id="username/repo_name", + task="text-to-image", + status="live", + ), ) assert payload == { "prompt": "a beautiful cat", @@ -363,7 +378,14 @@ def test_prepare_url(self): def test_prepare_payload_as_dict(self): helper = FireworksAIConversationalTask() payload = helper._prepare_payload_as_dict( - [{"role": "user", "content": "Hello!"}], {}, "meta-llama/Llama-3.1-8B-Instruct" + [{"role": "user", "content": "Hello!"}], + {}, + ProviderMappingInfo( + hf_model_id="meta-llama/Llama-3.1-8B-Instruct", + provider_id="meta-llama/Llama-3.1-8B-Instruct", + task="conversational", + status="live", + ), ) assert payload == { "messages": [{"role": "user", "content": "Hello!"}], @@ -401,25 +423,41 @@ def test_prepare_url(self): def test_prepare_payload_as_dict(self): helper = HFInferenceTask("text-classification") + mapping_info = ProviderMappingInfo( + hf_model_id="username/repo_name", + provider_id="username/repo_name", + task="text-classification", + status="live", + ) assert helper._prepare_payload_as_dict( "dummy text input", parameters={"a": 1, "b": None}, - mapped_model="username/repo_name", + provider_mapping_info=mapping_info, ) == { "inputs": "dummy text input", "parameters": {"a": 1}, } with pytest.raises(ValueError, match="Unexpected binary input for task text-classification."): - helper._prepare_payload_as_dict(b"dummy binary data", {}, "username/repo_name") + helper._prepare_payload_as_dict( + b"dummy binary data", + {}, + mapping_info, + ) def test_prepare_payload_as_bytes(self): helper = HFInferenceBinaryInputTask("image-classification") + mapping_info = ProviderMappingInfo( + hf_model_id="username/repo_name", + provider_id="username/repo_name", + task="image-classification", + status="live", + ) assert ( helper._prepare_payload_as_bytes( b"dummy binary input", parameters={}, - mapped_model="username/repo_name", + provider_mapping_info=mapping_info, extra_payload=None, ) == b"dummy binary input" @@ -429,7 +467,7 @@ def test_prepare_payload_as_bytes(self): helper._prepare_payload_as_bytes( b"dummy binary input", parameters={"a": 1, "b": None}, - mapped_model="username/repo_name", + provider_mapping_info=mapping_info, extra_payload={"extra": "payload"}, ) == b'{"inputs": "ZHVtbXkgYmluYXJ5IGlucHV0", "parameters": {"a": 1}, "extra": "payload"}' @@ -533,11 +571,16 @@ def test_prepare_request_conversational(self, mocker): def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters, expected_model): helper = HFInferenceConversational() messages = [{"role": "user", "content": "Hello!"}] - + provider_mapping_info = ProviderMappingInfo( + hf_model_id=mapped_model, + provider_id=mapped_model, + task="conversational", + status="live", + ) payload = helper._prepare_payload_as_dict( inputs=messages, parameters=parameters, - mapped_model=mapped_model, + provider_mapping_info=provider_mapping_info, ) assert payload["model"] == expected_model @@ -663,7 +706,14 @@ def test_prepare_payload_conversational(self): """Test payload preparation for conversational task.""" helper = HyperbolicTextGenerationTask("conversational") payload = helper._prepare_payload_as_dict( - [{"role": "user", "content": "Hello!"}], {"temperature": 0.7}, "meta-llama/Llama-3.2-3B-Instruct" + [{"role": "user", 
"content": "Hello!"}], + {"temperature": 0.7}, + ProviderMappingInfo( + hf_model_id="meta-llama/Llama-3.2-3B-Instruct", + provider_id="meta-llama/Llama-3.2-3B-Instruct", + task="conversational", + status="live", + ), ) assert payload == { "messages": [{"role": "user", "content": "Hello!"}], @@ -683,7 +733,12 @@ def test_prepare_payload_text_to_image(self): "height": 512, "seed": 42, }, - "stabilityai/sdxl", + ProviderMappingInfo( + hf_model_id="stabilityai/sdxl-turbo", + provider_id="stabilityai/sdxl", + task="text-to-image", + status="live", + ), ) assert payload == { "prompt": "a beautiful cat", @@ -713,7 +768,12 @@ def test_prepare_payload_as_dict_text_to_image(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 10, "width": 512, "height": 512, "guidance_scale": 7.5}, - "black-forest-labs/flux-schnell", + ProviderMappingInfo( + hf_model_id="black-forest-labs/flux-schnell", + provider_id="black-forest-labs/flux-schnell", + task="text-to-image", + status="live", + ), ) assert payload == { "prompt": "a beautiful cat", @@ -771,13 +831,27 @@ def test_prepare_payload_as_dict(self): # No model version payload = helper._prepare_payload_as_dict( - "a beautiful cat", {"num_inference_steps": 20}, "black-forest-labs/FLUX.1-schnell" + "a beautiful cat", + {"num_inference_steps": 20}, + ProviderMappingInfo( + hf_model_id="black-forest-labs/FLUX.1-schnell", + provider_id="black-forest-labs/FLUX.1-schnell", + task="text-to-image", + status="live", + ), ) assert payload == {"input": {"prompt": "a beautiful cat", "num_inference_steps": 20}} # Model with specific version payload = helper._prepare_payload_as_dict( - "a beautiful cat", {"num_inference_steps": 20}, "black-forest-labs/FLUX.1-schnell:1944af04d098ef" + "a beautiful cat", + {"num_inference_steps": 20}, + ProviderMappingInfo( + hf_model_id="black-forest-labs/FLUX.1-schnell", + provider_id="black-forest-labs/FLUX.1-schnell:1944af04d098ef", + task="text-to-image", + status="live", + ), ) assert payload == { "input": {"prompt": "a beautiful cat", "num_inference_steps": 20}, @@ -787,7 +861,14 @@ def test_prepare_payload_as_dict(self): def test_text_to_speech_payload(self): helper = ReplicateTextToSpeechTask() payload = helper._prepare_payload_as_dict( - "Hello world", {}, "hexgrad/Kokoro-82M:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13" + "Hello world", + {}, + ProviderMappingInfo( + hf_model_id="hexgrad/Kokoro-82M", + provider_id="hexgrad/Kokoro-82M:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", + task="text-to-speech", + status="live", + ), ) assert payload == { "input": {"text": "Hello world"}, @@ -826,7 +907,12 @@ def test_prepare_payload_as_dict_text_to_image(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 10, "guidance_scale": 1, "width": 512, "height": 512}, - "black-forest-labs/FLUX.1-schnell", + ProviderMappingInfo( + hf_model_id="black-forest-labs/FLUX.1-schnell", + provider_id="black-forest-labs/FLUX.1-schnell", + task="text-to-image", + status="live", + ), ) assert payload == { "prompt": "a beautiful cat", @@ -858,14 +944,19 @@ def test_prepare_payload(self): payload = helper._prepare_payload_as_dict( inputs=messages, parameters=parameters, - mapped_model="test-model", + provider_mapping_info=ProviderMappingInfo( + hf_model_id="test-model", + provider_id="test-provider-id", + task="conversational", + status="live", + ), ) assert payload == { "messages": messages, "temperature": 0.7, "max_tokens": 100, - "model": 
"test-model", + "model": "test-provider-id", } @@ -883,14 +974,19 @@ def test_prepare_payload(self): payload = helper._prepare_payload_as_dict( inputs=prompt, parameters=parameters, - mapped_model="test-model", + provider_mapping_info=ProviderMappingInfo( + hf_model_id="test-model", + provider_id="test-provider-id", + task="text-generation", + status="live", + ), ) assert payload == { "prompt": prompt, "temperature": 0.7, "max_tokens": 100, - "model": "test-model", + "model": "test-provider-id", } From 312a21039bb3af0726ec55e54b6e492fbb5ab419 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 17 Apr 2025 12:23:40 +0200 Subject: [PATCH 06/11] fixes --- .../inference/_providers/_common.py | 26 +++++++++---------- .../inference/_providers/fal_ai.py | 10 +++---- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index fbbc5896da..f767e2346c 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -16,6 +16,7 @@ class ProviderMappingInfo: provider_id: str task: str status: Literal["live", "staging"] + adapter: Optional[str] = None adapter_weights_path: Optional[str] = None @@ -27,7 +28,10 @@ class ProviderMappingInfo: # "HF model ID" => ProviderMappingInfo object initialized with "Model ID on Inference Provider's side" # # Example: - # "Qwen/Qwen2.5-Coder-32B-Instruct": ProviderMappingInfo(provider_id="Qwen2.5-Coder-32B-Instruct") + # "Qwen/Qwen2.5-Coder-32B-Instruct": ProviderMappingInfo(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct", + # provider_id="Qwen2.5-Coder-32B-Instruct", + # task="conversational", + # status="live") "cerebras": {}, "cohere": {}, "fal-ai": {}, @@ -145,22 +149,16 @@ def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: logger.warning( f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only." ) - if provider_mapping.adapter == "lora": - adapter_weights_path = _fetch_lora_weights_path(model) - return ProviderMappingInfo( - adapter_weights_path=adapter_weights_path, - provider_id=provider_mapping.provider_id, - hf_model_id=model, - task=provider_mapping.task, - status=provider_mapping.status, - ) - - return ProviderMappingInfo( - provider_id=provider_mapping.provider_id, + mapping_info = ProviderMappingInfo( hf_model_id=model, - task=provider_mapping.task, + provider_id=provider_mapping.provider_id, status=provider_mapping.status, + task=provider_mapping.task, + adapter=provider_mapping.adapter, ) + if provider_mapping.adapter == "lora": + mapping_info.adapter_weights_path = _fetch_lora_weights_path(model) + return mapping_info def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: """Return the headers to use for the request. 
diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py index 21f48f55e7..3c578e397e 100644 --- a/src/huggingface_hub/inference/_providers/fal_ai.py +++ b/src/huggingface_hub/inference/_providers/fal_ai.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union from urllib.parse import urlparse -from huggingface_hub.constants import ENDPOINT +from huggingface_hub import constants from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none from huggingface_hub.utils import get_session, hf_raise_for_status @@ -77,10 +77,10 @@ def _prepare_payload_as_dict( "height": payload.pop("height"), } if provider_mapping_info.adapter_weights_path is not None: - lora_path = ( - f"{ENDPOINT}/" - f"{provider_mapping_info.hf_model_id}/resolve/main/" - f"{provider_mapping_info.adapter_weights_path}" + lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( + repo_id=provider_mapping_info.hf_model_id, + revision="main", + filename=provider_mapping_info.adapter_weights_path, ) payload["loras"] = [{"path": lora_path, "scale": 1}] if provider_mapping_info.provider_id == "fal-ai/lora": From f55e62872bbc6440b2e2a07d560971672346bfaf Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 24 Apr 2025 22:24:01 +0200 Subject: [PATCH 07/11] use the precomputed adapterWeightsPath property --- src/huggingface_hub/_inference_endpoints.py | 13 ++- src/huggingface_hub/hf_api.py | 7 +- .../inference/_providers/_common.py | 47 +++------- .../inference/_providers/black_forest_labs.py | 5 +- .../inference/_providers/fal_ai.py | 11 ++- .../inference/_providers/hf_inference.py | 21 +++-- .../inference/_providers/hyperbolic.py | 10 +- .../inference/_providers/nebius.py | 4 +- .../inference/_providers/novita.py | 4 +- .../inference/_providers/openai.py | 7 +- .../inference/_providers/replicate.py | 7 +- .../inference/_providers/together.py | 4 +- tests/test_inference_providers.py | 91 +++++++++++-------- 13 files changed, 120 insertions(+), 111 deletions(-) diff --git a/src/huggingface_hub/_inference_endpoints.py b/src/huggingface_hub/_inference_endpoints.py index 37733fef1b..52a31361b4 100644 --- a/src/huggingface_hub/_inference_endpoints.py +++ b/src/huggingface_hub/_inference_endpoints.py @@ -6,14 +6,13 @@ from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError -from .inference._client import InferenceClient -from .inference._generated._async_client import AsyncInferenceClient from .utils import get_session, logging, parse_datetime if TYPE_CHECKING: from .hf_api import HfApi - + from .inference._client import InferenceClient + from .inference._generated._async_client import AsyncInferenceClient logger = logging.get_logger(__name__) @@ -138,7 +137,7 @@ def __post_init__(self) -> None: self._populate_from_raw() @property - def client(self) -> InferenceClient: + def client(self) -> "InferenceClient": """Returns a client to make predictions on this Inference Endpoint. Returns: @@ -152,13 +151,15 @@ def client(self) -> InferenceClient: "Cannot create a client for this Inference Endpoint as it is not yet deployed. " "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." ) + from .inference._client import InferenceClient + return InferenceClient( model=self.url, token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. 
) @property - def async_client(self) -> AsyncInferenceClient: + def async_client(self) -> "AsyncInferenceClient": """Returns a client to make predictions on this Inference Endpoint. Returns: @@ -172,6 +173,8 @@ def async_client(self) -> AsyncInferenceClient: "Cannot create a client for this Inference Endpoint as it is not yet deployed. " "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." ) + from .inference._generated._async_client import AsyncInferenceClient + return AsyncInferenceClient( model=self.url, token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 493847e2d9..a3c006b69f 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -708,16 +708,21 @@ def __init__(self, **kwargs): @dataclass class InferenceProviderMapping: + hf_model_id: str status: Literal["live", "staging"] provider_id: str task: str + adapter: Optional[str] = None + adapter_weights_path: Optional[str] = None def __init__(self, **kwargs): + self.hf_model_id = kwargs.pop("hf_model_id") self.status = kwargs.pop("status") self.provider_id = kwargs.pop("providerId") self.task = kwargs.pop("task") self.adapter = kwargs.pop("adapter", None) + self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None) self.__dict__.update(**kwargs) @@ -849,7 +854,7 @@ def __init__(self, **kwargs): self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None) if self.inference_provider_mapping: self.inference_provider_mapping = { - provider: InferenceProviderMapping(**value) + provider: InferenceProviderMapping(**{**value, "hf_model_id": self.id}) for provider, value in self.inference_provider_mapping.items() } diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index f767e2346c..77f1fc97d6 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -1,34 +1,22 @@ -from dataclasses import dataclass from functools import lru_cache -from typing import Any, Dict, Literal, Optional, Union +from typing import Any, Dict, Optional, Union from huggingface_hub import constants +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters from huggingface_hub.utils import build_hf_headers, get_token, logging logger = logging.get_logger(__name__) - -@dataclass -class ProviderMappingInfo: - hf_model_id: str - provider_id: str - task: str - status: Literal["live", "staging"] - - adapter: Optional[str] = None - adapter_weights_path: Optional[str] = None - - # Dev purposes only. # If you want to try to run inference for a new model locally before it's registered on huggingface.co # for a given Inference Provider, you can add it to the following dictionary. 
-HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, ProviderMappingInfo]] = { - # "HF model ID" => ProviderMappingInfo object initialized with "Model ID on Inference Provider's side" +HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = { + # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side" # # Example: - # "Qwen/Qwen2.5-Coder-32B-Instruct": ProviderMappingInfo(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct", + # "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct", # provider_id="Qwen2.5-Coder-32B-Instruct", # task="conversational", # status="live") @@ -124,7 +112,7 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: ) return api_key - def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: """Return the mapped model ID to use for the request. Usually not overwritten in subclasses.""" @@ -149,16 +137,7 @@ def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: logger.warning( f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only." ) - mapping_info = ProviderMappingInfo( - hf_model_id=model, - provider_id=provider_mapping.provider_id, - status=provider_mapping.status, - task=provider_mapping.task, - adapter=provider_mapping.adapter, - ) - if provider_mapping.adapter == "lora": - mapping_info.adapter_weights_path = _fetch_lora_weights_path(model) - return mapping_info + return provider_mapping def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: """Return the headers to use for the request. @@ -195,7 +174,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: """Return the payload to use for the request, as a dict. @@ -205,7 +184,11 @@ def _prepare_payload_as_dict( return None def _prepare_payload_as_bytes( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo, extra_payload: Optional[Dict] + self, + inputs: Any, + parameters: Dict, + provider_mapping_info: InferenceProviderMapping, + extra_payload: Optional[Dict], ) -> Optional[bytes]: """Return the body to use for the request, as bytes. 
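Note (illustrative only): with this refactor the local-testing escape hatch stores full InferenceProviderMapping objects rather than bare provider IDs. A minimal sketch of registering an entry for local testing, assuming the keyword names accepted by InferenceProviderMapping.__init__ in the hf_api.py hunk above (providerId, task, status, hf_model_id, plus optional adapter/adapterWeightsPath), and assuming the outer dict is still keyed by provider name as in the original hardcoded mapping; the model ID and weights filename below are hypothetical:

    from huggingface_hub.hf_api import InferenceProviderMapping
    from huggingface_hub.inference._providers import _common

    # Hypothetical LoRA repo mapped to a fal-ai model, for local testing only.
    _common.HARDCODED_MODEL_INFERENCE_MAPPING.setdefault("fal-ai", {})[
        "my-user/my-flux-lora"
    ] = InferenceProviderMapping(
        hf_model_id="my-user/my-flux-lora",
        providerId="fal-ai/flux-lora",  # model ID on the provider's side
        task="text-to-image",
        status="live",
        adapter="lora",
        adapterWeightsPath="pytorch_lora_weights.safetensors",
    )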
@@ -228,7 +211,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/chat/completions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} @@ -246,7 +229,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/completions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} diff --git a/src/huggingface_hub/inference/_providers/black_forest_labs.py b/src/huggingface_hub/inference/_providers/black_forest_labs.py index bdbfff8f1b..afa8ed281d 100644 --- a/src/huggingface_hub/inference/_providers/black_forest_labs.py +++ b/src/huggingface_hub/inference/_providers/black_forest_labs.py @@ -1,8 +1,9 @@ import time from typing import Any, Dict, Optional, Union +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import logging from huggingface_hub.utils._http import get_session @@ -28,7 +29,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v1/{mapped_model}" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: parameters = filter_none(parameters) if "num_inference_steps" in parameters: diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py index 3c578e397e..b117fcb201 100644 --- a/src/huggingface_hub/inference/_providers/fal_ai.py +++ b/src/huggingface_hub/inference/_providers/fal_ai.py @@ -5,8 +5,9 @@ from urllib.parse import urlparse from huggingface_hub import constants +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import get_session, hf_raise_for_status from huggingface_hub.utils.logging import get_logger @@ -36,7 +37,7 @@ def __init__(self): super().__init__("automatic-speech-recognition") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): # If input is a URL, pass it directly @@ -65,7 +66,7 @@ def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, 
provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: payload: Dict[str, Any] = { "prompt": inputs, @@ -98,7 +99,7 @@ def __init__(self): super().__init__("text-to-speech") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"lyrics": inputs, **filter_none(parameters)} @@ -125,7 +126,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/{mapped_model}" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters)} diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index f3092e7a12..b949f6596c 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Optional from huggingface_hub import constants +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import _b64_encode, _open_as_binary -from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status @@ -23,9 +24,9 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: # special case: for HF Inference we allow not providing an API key return api_key or get_token() # type: ignore[return-value] - def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: if model is not None and model.startswith(("http://", "https://")): - return ProviderMappingInfo(provider_id=model, hf_model_id=model, task=self.task, status="live") + return InferenceProviderMapping(providerId=model, hf_model_id=model, task=self.task, status="live") model_id = model if model is not None else _fetch_recommended_models().get(self.task) if model_id is None: raise ValueError( @@ -33,7 +34,7 @@ def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: " explicitly. Visit https://huggingface.co/tasks for more info." ) _check_supported_task(model_id, self.task) - return ProviderMappingInfo(provider_id=model_id, hf_model_id=model_id, task=self.task, status="live") + return InferenceProviderMapping(providerId=model_id, hf_model_id=model_id, task=self.task, status="live") def _prepare_url(self, api_key: str, mapped_model: str) -> str: # hf-inference provider can handle URLs (e.g. 
Inference Endpoints or TGI deployment) @@ -48,7 +49,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str: ) def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") @@ -59,12 +60,16 @@ def _prepare_payload_as_dict( class HFInferenceBinaryInputTask(HFInferenceTask): def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return None def _prepare_payload_as_bytes( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo, extra_payload: Optional[Dict] + self, + inputs: Any, + parameters: Dict, + provider_mapping_info: InferenceProviderMapping, + extra_payload: Optional[Dict], ) -> Optional[bytes]: parameters = filter_none({k: v for k, v in parameters.items() if v is not None}) extra_payload = extra_payload or {} @@ -89,7 +94,7 @@ def __init__(self): super().__init__("conversational") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id payload_model = parameters.get("model") or mapped_model diff --git a/src/huggingface_hub/inference/_providers/hyperbolic.py b/src/huggingface_hub/inference/_providers/hyperbolic.py index 7bb970b9c7..6dcb14cc27 100644 --- a/src/huggingface_hub/inference/_providers/hyperbolic.py +++ b/src/huggingface_hub/inference/_providers/hyperbolic.py @@ -1,13 +1,9 @@ import base64 from typing import Any, Dict, Optional, Union +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import ( - BaseConversationalTask, - ProviderMappingInfo, - TaskProviderHelper, - filter_none, -) +from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none class HyperbolicTextToImageTask(TaskProviderHelper): @@ -18,7 +14,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/images/generations" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) diff --git a/src/huggingface_hub/inference/_providers/nebius.py b/src/huggingface_hub/inference/_providers/nebius.py index 51a51f72fb..8593872a81 100644 --- a/src/huggingface_hub/inference/_providers/nebius.py +++ b/src/huggingface_hub/inference/_providers/nebius.py @@ -1,11 +1,11 @@ import base64 from typing import Any, Dict, Optional, Union +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, - ProviderMappingInfo, TaskProviderHelper, filter_none, ) @@ -39,7 +39,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return 
"/v1/images/generations" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) diff --git a/src/huggingface_hub/inference/_providers/novita.py b/src/huggingface_hub/inference/_providers/novita.py index f8134b9d6c..44adc9017b 100644 --- a/src/huggingface_hub/inference/_providers/novita.py +++ b/src/huggingface_hub/inference/_providers/novita.py @@ -1,10 +1,10 @@ from typing import Any, Dict, Optional, Union +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, - ProviderMappingInfo, TaskProviderHelper, filter_none, ) @@ -51,7 +51,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v3/hf/{mapped_model}" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return {"prompt": inputs, **filter_none(parameters)} diff --git a/src/huggingface_hub/inference/_providers/openai.py b/src/huggingface_hub/inference/_providers/openai.py index 50f9f67f21..4ea95f1643 100644 --- a/src/huggingface_hub/inference/_providers/openai.py +++ b/src/huggingface_hub/inference/_providers/openai.py @@ -1,6 +1,7 @@ from typing import Optional -from huggingface_hub.inference._providers._common import BaseConversationalTask, ProviderMappingInfo +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._providers._common import BaseConversationalTask class OpenAIConversationalTask(BaseConversationalTask): @@ -16,7 +17,7 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str: ) return api_key - def _prepare_mapping_info(self, model: Optional[str]) -> ProviderMappingInfo: + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: if model is None: raise ValueError("Please provide an OpenAI model ID, e.g. 
`gpt-4o` or `o1`.") - return ProviderMappingInfo(provider_id=model, task="conversational", status="live", hf_model_id=model) + return InferenceProviderMapping(providerId=model, task="conversational", status="live", hf_model_id=model) diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index 115ed5a779..d76eaa2b5a 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -1,7 +1,8 @@ from typing import Any, Dict, Optional, Union +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict -from huggingface_hub.inference._providers._common import ProviderMappingInfo, TaskProviderHelper, filter_none +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import get_session @@ -24,7 +25,7 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return f"/v1/models/{mapped_model}/predictions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}} @@ -51,7 +52,7 @@ def __init__(self): super().__init__("text-to-speech") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS diff --git a/src/huggingface_hub/inference/_providers/together.py b/src/huggingface_hub/inference/_providers/together.py index 3de2b36237..b27e332938 100644 --- a/src/huggingface_hub/inference/_providers/together.py +++ b/src/huggingface_hub/inference/_providers/together.py @@ -2,11 +2,11 @@ from abc import ABC from typing import Any, Dict, Optional, Union +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, - ProviderMappingInfo, TaskProviderHelper, filter_none, ) @@ -57,7 +57,7 @@ def __init__(self): super().__init__("text-to-image") def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: ProviderMappingInfo + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: mapped_model = provider_mapping_info.provider_id parameters = filter_none(parameters) diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 04a5d3e6d2..68425013aa 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -6,11 +6,11 @@ import pytest from pytest import LogCaptureFixture +from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters from huggingface_hub.inference._providers._common import ( BaseConversationalTask, BaseTextGenerationTask, - ProviderMappingInfo, TaskProviderHelper, recursive_merge, ) @@ -80,7 
+80,8 @@ def test_prepare_mapping_info(self, mocker, caplog: LogCaptureFixture): return_value={ "provider-name": mocker.Mock( task="other-task", - provider_id="mapped-id", + providerId="mapped-id", + hf_model_id="test-model", status="live", ) }, @@ -91,7 +92,11 @@ def test_prepare_mapping_info(self, mocker, caplog: LogCaptureFixture): # Test staging model mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="staging")}, + return_value={ + "provider-name": mocker.Mock( + task="task-name", hf_model_id="test-model", provider_id="mapped-id", status="staging" + ) + }, ) assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" @@ -103,7 +108,11 @@ def test_prepare_mapping_info(self, mocker, caplog: LogCaptureFixture): caplog.clear() mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="live")}, + return_value={ + "provider-name": mocker.Mock( + task="task-name", hf_model_id="test-model", provider_id="mapped-id", status="live" + ) + }, ) assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" assert helper._prepare_mapping_info("test-model").hf_model_id == "test-model" @@ -115,13 +124,17 @@ def test_prepare_mapping_info(self, mocker, caplog: LogCaptureFixture): mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", return_value={ - "provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="live", adapter="lora") + "provider-name": mocker.Mock( + task="task-name", + hf_model_id="test-model", + provider_id="mapped-id", + status="live", + adapter_weights_path="lora-weights-path", + adapter="lora", + ) }, ) - mocker.patch( - "huggingface_hub.inference._providers._common._fetch_lora_weights_path", - return_value="lora-weights-path", - ) + assert helper._prepare_mapping_info("test-model").adapter_weights_path == "lora-weights-path" assert helper._prepare_mapping_info("test-model").provider_id == "mapped-id" assert helper._prepare_mapping_info("test-model").hf_model_id == "test-model" @@ -236,9 +249,9 @@ def test_prepare_payload_as_dict(self): payload = helper._prepare_payload_as_dict( [{"role": "user", "content": "Hello!"}], {}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="CohereForAI/command-r7b-12-2024", - provider_id="CohereForAI/command-r7b-12-2024", + providerId="CohereForAI/command-r7b-12-2024", task="conversational", status="live", ), @@ -285,9 +298,9 @@ def test_text_to_image_payload(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"width": 512, "height": 512}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="username/repo_name", - provider_id="username/repo_name", + providerId="username/repo_name", task="text-to-image", status="live", ), @@ -380,9 +393,9 @@ def test_prepare_payload_as_dict(self): payload = helper._prepare_payload_as_dict( [{"role": "user", "content": "Hello!"}], {}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="meta-llama/Llama-3.1-8B-Instruct", - provider_id="meta-llama/Llama-3.1-8B-Instruct", + providerId="meta-llama/Llama-3.1-8B-Instruct", task="conversational", status="live", ), @@ -423,9 +436,9 @@ def test_prepare_url(self): def test_prepare_payload_as_dict(self): helper = HFInferenceTask("text-classification") - mapping_info = 
ProviderMappingInfo( + mapping_info = InferenceProviderMapping( hf_model_id="username/repo_name", - provider_id="username/repo_name", + providerId="username/repo_name", task="text-classification", status="live", ) @@ -447,9 +460,9 @@ def test_prepare_payload_as_dict(self): def test_prepare_payload_as_bytes(self): helper = HFInferenceBinaryInputTask("image-classification") - mapping_info = ProviderMappingInfo( + mapping_info = InferenceProviderMapping( hf_model_id="username/repo_name", - provider_id="username/repo_name", + providerId="username/repo_name", task="image-classification", status="live", ) @@ -571,9 +584,9 @@ def test_prepare_request_conversational(self, mocker): def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters, expected_model): helper = HFInferenceConversational() messages = [{"role": "user", "content": "Hello!"}] - provider_mapping_info = ProviderMappingInfo( + provider_mapping_info = InferenceProviderMapping( hf_model_id=mapped_model, - provider_id=mapped_model, + providerId=mapped_model, task="conversational", status="live", ) @@ -708,9 +721,9 @@ def test_prepare_payload_conversational(self): payload = helper._prepare_payload_as_dict( [{"role": "user", "content": "Hello!"}], {"temperature": 0.7}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="meta-llama/Llama-3.2-3B-Instruct", - provider_id="meta-llama/Llama-3.2-3B-Instruct", + providerId="meta-llama/Llama-3.2-3B-Instruct", task="conversational", status="live", ), @@ -733,9 +746,9 @@ def test_prepare_payload_text_to_image(self): "height": 512, "seed": 42, }, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="stabilityai/sdxl-turbo", - provider_id="stabilityai/sdxl", + providerId="stabilityai/sdxl", task="text-to-image", status="live", ), @@ -768,9 +781,9 @@ def test_prepare_payload_as_dict_text_to_image(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 10, "width": 512, "height": 512, "guidance_scale": 7.5}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="black-forest-labs/flux-schnell", - provider_id="black-forest-labs/flux-schnell", + providerId="black-forest-labs/flux-schnell", task="text-to-image", status="live", ), @@ -833,9 +846,9 @@ def test_prepare_payload_as_dict(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 20}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="black-forest-labs/FLUX.1-schnell", - provider_id="black-forest-labs/FLUX.1-schnell", + providerId="black-forest-labs/FLUX.1-schnell", task="text-to-image", status="live", ), @@ -846,9 +859,9 @@ def test_prepare_payload_as_dict(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 20}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="black-forest-labs/FLUX.1-schnell", - provider_id="black-forest-labs/FLUX.1-schnell:1944af04d098ef", + providerId="black-forest-labs/FLUX.1-schnell:1944af04d098ef", task="text-to-image", status="live", ), @@ -863,9 +876,9 @@ def test_text_to_speech_payload(self): payload = helper._prepare_payload_as_dict( "Hello world", {}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="hexgrad/Kokoro-82M", - provider_id="hexgrad/Kokoro-82M:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", + providerId="hexgrad/Kokoro-82M:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", task="text-to-speech", status="live", ), @@ -907,9 +920,9 @@ def 
test_prepare_payload_as_dict_text_to_image(self): payload = helper._prepare_payload_as_dict( "a beautiful cat", {"num_inference_steps": 10, "guidance_scale": 1, "width": 512, "height": 512}, - ProviderMappingInfo( + InferenceProviderMapping( hf_model_id="black-forest-labs/FLUX.1-schnell", - provider_id="black-forest-labs/FLUX.1-schnell", + providerId="black-forest-labs/FLUX.1-schnell", task="text-to-image", status="live", ), @@ -944,9 +957,9 @@ def test_prepare_payload(self): payload = helper._prepare_payload_as_dict( inputs=messages, parameters=parameters, - provider_mapping_info=ProviderMappingInfo( + provider_mapping_info=InferenceProviderMapping( hf_model_id="test-model", - provider_id="test-provider-id", + providerId="test-provider-id", task="conversational", status="live", ), @@ -974,9 +987,9 @@ def test_prepare_payload(self): payload = helper._prepare_payload_as_dict( inputs=prompt, parameters=parameters, - provider_mapping_info=ProviderMappingInfo( + provider_mapping_info=InferenceProviderMapping( hf_model_id="test-model", - provider_id="test-provider-id", + providerId="test-provider-id", task="text-generation", status="live", ), From b59a1234614db457524a5d4080aa077201f179cb Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Fri, 25 Apr 2025 15:49:36 +0200 Subject: [PATCH 08/11] remove unnecessary function --- src/huggingface_hub/inference/_providers/_common.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 77f1fc97d6..afea6fe29f 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -248,17 +248,6 @@ def _fetch_inference_provider_mapping(model: str) -> Dict: return provider_mapping -@lru_cache(maxsize=None) -def _fetch_lora_weights_path(model: str) -> str: - from huggingface_hub.hf_api import HfApi - - repo_files = HfApi().list_repo_files(model) - safetensors_files = [f for f in repo_files if f.endswith(".safetensors")] - if len(safetensors_files) != 1: - raise ValueError(f"Expected exactly one safetensors file in repo {model}, got {len(safetensors_files)}.") - return safetensors_files[0] - - def recursive_merge(dict1: Dict, dict2: Dict) -> Dict: return { **dict1, From b98261230040eec143ecf59fe5b9f19906df1ba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Tue, 29 Apr 2025 15:45:28 +0200 Subject: [PATCH 09/11] Update src/huggingface_hub/hf_api.py Co-authored-by: Lucain --- src/huggingface_hub/hf_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index af09120b42..dca5ac8b83 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -854,7 +854,7 @@ def __init__(self, **kwargs): self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None) if self.inference_provider_mapping: self.inference_provider_mapping = { - provider: InferenceProviderMapping(**{**value, "hf_model_id": self.id}) + provider: InferenceProviderMapping(**{**value, "hf_model_id": self.id}) # little hack to simplify Inference Providers logic for provider, value in self.inference_provider_mapping.items() } From b300bce5b3b21b97ebca57d76676d780f5cd9a3f Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 29 Apr 2025 15:48:35 +0200 Subject: [PATCH 10/11] style --- src/huggingface_hub/hf_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index dca5ac8b83..ded959d942 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -854,7 +854,9 @@ def __init__(self, **kwargs): self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None) if self.inference_provider_mapping: self.inference_provider_mapping = { - provider: InferenceProviderMapping(**{**value, "hf_model_id": self.id}) # little hack to simplify Inference Providers logic + provider: InferenceProviderMapping( + **{**value, "hf_model_id": self.id} + ) # little hack to simplify Inference Providers logic for provider, value in self.inference_provider_mapping.items() } From a2974bfcb43b2c102e28ac3ca33fb922df6cc1ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Tue, 29 Apr 2025 16:01:35 +0200 Subject: [PATCH 11/11] add comment Co-authored-by: Lucain --- src/huggingface_hub/inference/_providers/fal_ai.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py index 3a9ed77ba8..8dd463b6b1 100644 --- a/src/huggingface_hub/inference/_providers/fal_ai.py +++ b/src/huggingface_hub/inference/_providers/fal_ai.py @@ -85,6 +85,8 @@ def _prepare_payload_as_dict( ) payload["loras"] = [{"path": lora_path, "scale": 1}] if provider_mapping_info.provider_id == "fal-ai/lora": + # little hack: fal requires the base model for stable-diffusion-based loras but not for flux-based + # See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0" return payload
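Taken together, the intended end-to-end behaviour of this series can be sketched as follows. This is illustrative only: the repo ID and weights filename are hypothetical, and it assumes HUGGINGFACE_CO_URL_TEMPLATE keeps the repo_id/revision/filename placeholders used in the fal_ai.py hunk above:

    from huggingface_hub import constants
    from huggingface_hub.hf_api import InferenceProviderMapping

    # Mapping as it would come back from the Hub for a hypothetical LoRA repo.
    mapping = InferenceProviderMapping(
        hf_model_id="my-user/my-sdxl-lora",
        providerId="fal-ai/lora",
        task="text-to-image",
        status="live",
        adapter="lora",
        adapterWeightsPath="pytorch_lora_weights.safetensors",
    )

    # Same URL construction as FalAITextToImageTask._prepare_payload_as_dict.
    lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
        repo_id=mapping.hf_model_id,
        revision="main",
        filename=mapping.adapter_weights_path,
    )
    payload = {"prompt": "a beautiful cat", "loras": [{"path": lora_path, "scale": 1}]}
    if mapping.provider_id == "fal-ai/lora":
        # fal requires the base model for stable-diffusion-based loras, not for flux-based ones.
        payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0"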