[Inference Providers] Support for LoRAs #3005

Merged · 16 commits · Apr 29, 2025
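
For context, a minimal usage sketch of what this PR enables, with a hypothetical LoRA repo ID: once a Hub LoRA repo is mapped to a provider (here fal-ai), it can be called through `InferenceClient` like any other text-to-image model.

```python
from huggingface_hub import InferenceClient

# Hypothetical LoRA repo on the Hub; the provider mapping (adapter type and
# adapter weights path) is resolved automatically from the model info.
client = InferenceClient(provider="fal-ai")
image = client.text_to_image(
    "a watercolor fox in a forest",
    model="username/my-flux-lora",
)
image.save("fox.png")
```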
13 changes: 8 additions & 5 deletions src/huggingface_hub/_inference_endpoints.py
@@ -6,14 +6,13 @@

from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError

from .inference._client import InferenceClient
from .inference._generated._async_client import AsyncInferenceClient
from .utils import get_session, logging, parse_datetime


if TYPE_CHECKING:
from .hf_api import HfApi

from .inference._client import InferenceClient
from .inference._generated._async_client import AsyncInferenceClient

logger = logging.get_logger(__name__)

@@ -138,7 +137,7 @@ def __post_init__(self) -> None:
self._populate_from_raw()

@property
def client(self) -> InferenceClient:
def client(self) -> "InferenceClient":
"""Returns a client to make predictions on this Inference Endpoint.
Returns:
@@ -152,13 +151,15 @@ def client(self) -> InferenceClient:
"Cannot create a client for this Inference Endpoint as it is not yet deployed. "
"Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
)
from .inference._client import InferenceClient

return InferenceClient(
model=self.url,
token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
)

@property
def async_client(self) -> AsyncInferenceClient:
def async_client(self) -> "AsyncInferenceClient":
"""Returns a client to make predictions on this Inference Endpoint.
Returns:
@@ -172,6 +173,8 @@ def async_client(self) -> AsyncInferenceClient:
"Cannot create a client for this Inference Endpoint as it is not yet deployed. "
"Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
)
from .inference._generated._async_client import AsyncInferenceClient

return AsyncInferenceClient(
model=self.url,
token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
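The imports of `InferenceClient` and `AsyncInferenceClient` are moved under `TYPE_CHECKING` and re-imported inside the properties, which avoids a circular import between `_inference_endpoints.py` and the inference clients while keeping the type annotations. A minimal sketch of the pattern, with illustrative names:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, so it cannot create an import cycle at runtime.
    from huggingface_hub import InferenceClient


class MyEndpoint:
    def __init__(self, url: str, token: str) -> None:
        self.url = url
        self._token = token

    @property
    def client(self) -> "InferenceClient":
        # Imported lazily, once all modules are fully initialized.
        from huggingface_hub import InferenceClient

        return InferenceClient(model=self.url, token=self._token)
```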
11 changes: 10 additions & 1 deletion src/huggingface_hub/hf_api.py
@@ -708,14 +708,21 @@ def __init__(self, **kwargs):

@dataclass
class InferenceProviderMapping:
hf_model_id: str
status: Literal["live", "staging"]
provider_id: str
task: str

adapter: Optional[str] = None
adapter_weights_path: Optional[str] = None

def __init__(self, **kwargs):
self.hf_model_id = kwargs.pop("hf_model_id")
self.status = kwargs.pop("status")
self.provider_id = kwargs.pop("providerId")
self.task = kwargs.pop("task")
self.adapter = kwargs.pop("adapter", None)
self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None)
self.__dict__.update(**kwargs)


@@ -847,7 +854,9 @@ def __init__(self, **kwargs):
self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
if self.inference_provider_mapping:
self.inference_provider_mapping = {
provider: InferenceProviderMapping(**value)
provider: InferenceProviderMapping(
**{**value, "hf_model_id": self.id}
) # little hack to simplify Inference Providers logic
for provider, value in self.inference_provider_mapping.items()
}

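The Hub returns provider mapping entries with camelCase keys (`providerId`, `adapterWeightsPath`), and `hf_model_id` is injected from the parent `ModelInfo` so provider helpers can later resolve a URL to the adapter weights. A sketch of how a raw entry is parsed, assuming a hypothetical LoRA repo:

```python
from huggingface_hub.hf_api import InferenceProviderMapping

# Hypothetical raw entry, shaped like what the Hub API returns for one provider.
raw = {
    "status": "live",
    "providerId": "fal-ai/flux-lora",
    "task": "text-to-image",
    "adapter": "lora",
    "adapterWeightsPath": "pytorch_lora_weights.safetensors",
}

# Same "little hack" as above: the HF model ID is injected alongside the raw fields.
mapping = InferenceProviderMapping(**{**raw, "hf_model_id": "username/my-flux-lora"})
print(mapping.provider_id, mapping.adapter_weights_path)
```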
53 changes: 34 additions & 19 deletions src/huggingface_hub/inference/_providers/_common.py
@@ -2,21 +2,24 @@
from typing import Any, Dict, Optional, Union

from huggingface_hub import constants
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters
from huggingface_hub.utils import build_hf_headers, get_token, logging


logger = logging.get_logger(__name__)


# Dev purposes only.
# If you want to try to run inference for a new model locally before it's registered on huggingface.co
# for a given Inference Provider, you can add it to the following dictionary.
HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
# "HF model ID" => "Model ID on Inference Provider's side"
HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = {
# "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
#
# Example:
# "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
# "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
# provider_id="Qwen2.5-Coder-32B-Instruct",
# task="conversational",
# status="live")
"cerebras": {},
"cohere": {},
"fal-ai": {},
@@ -61,28 +64,30 @@ def prepare_request(
api_key = self._prepare_api_key(api_key)

# mapped model from HF model ID
mapped_model = self._prepare_mapped_model(model)
provider_mapping_info = self._prepare_mapping_info(model)

# default HF headers + user headers (to customize in subclasses)
headers = self._prepare_headers(headers, api_key)

# routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
url = self._prepare_url(api_key, mapped_model)
url = self._prepare_url(api_key, provider_mapping_info.provider_id)

# prepare payload (to customize in subclasses)
payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
payload = self._prepare_payload_as_dict(inputs, parameters, provider_mapping_info=provider_mapping_info)
if payload is not None:
payload = recursive_merge(payload, extra_payload or {})

# body data (to customize in subclasses)
data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
data = self._prepare_payload_as_bytes(inputs, parameters, provider_mapping_info, extra_payload)

# check if both payload and data are set and return
if payload is not None and data is not None:
raise ValueError("Both payload and data cannot be set in the same request.")
if payload is None and data is None:
raise ValueError("Either payload or data must be set in the request.")
return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)
return RequestParameters(
url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers
)

def get_response(
self,
@@ -107,16 +112,16 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str:
)
return api_key

def _prepare_mapped_model(self, model: Optional[str]) -> str:
def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
"""Return the mapped model ID to use for the request.
Usually not overwritten in subclasses."""
if model is None:
raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")

# hardcoded mapping for local testing
if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
return HARDCODED_MODEL_ID_MAPPING[self.provider][model]
if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model):
return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model]

provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
if provider_mapping is None:
@@ -132,7 +137,7 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str:
logger.warning(
f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
)
return provider_mapping.provider_id
return provider_mapping

def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
"""Return the headers to use for the request.
@@ -168,7 +173,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
"""
return ""

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
"""Return the payload to use for the request, as a dict.
Override this method in subclasses for customized payloads.
@@ -177,7 +184,11 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
return None

def _prepare_payload_as_bytes(
self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
self,
inputs: Any,
parameters: Dict,
provider_mapping_info: InferenceProviderMapping,
extra_payload: Optional[Dict],
) -> Optional[bytes]:
"""Return the body to use for the request, as bytes.
@@ -199,8 +210,10 @@ def __init__(self, provider: str, base_url: str):
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
return "/v1/chat/completions"

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}


class BaseTextGenerationTask(TaskProviderHelper):
@@ -215,8 +228,10 @@ def __init__(self, provider: str, base_url: str):
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
return "/v1/completions"

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
return {"prompt": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}


@lru_cache(maxsize=None)
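`HARDCODED_MODEL_INFERENCE_MAPPING` now stores full `InferenceProviderMapping` objects instead of bare provider model IDs, so a model tested locally can also carry adapter metadata. A sketch of registering such an entry for local testing, with hypothetical IDs (keyword names here follow the keys popped by the `__init__` shown earlier):

```python
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._providers._common import HARDCODED_MODEL_INFERENCE_MAPPING

# Dev/testing only: map a Hub LoRA repo to a provider-side model before it is
# registered on huggingface.co.
HARDCODED_MODEL_INFERENCE_MAPPING["fal-ai"]["username/my-sdxl-lora"] = InferenceProviderMapping(
    hf_model_id="username/my-sdxl-lora",
    providerId="fal-ai/lora",
    task="text-to-image",
    status="live",
    adapter="lora",
    adapterWeightsPath="pytorch_lora_weights.safetensors",
)
```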
@@ -1,6 +1,7 @@
import time
from typing import Any, Dict, Optional, Union

from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import logging
@@ -27,7 +28,9 @@ def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
return f"/v1/{mapped_model}"

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
parameters = filter_none(parameters)
if "num_inference_steps" in parameters:
parameters["steps"] = parameters.pop("num_inference_steps")
45 changes: 35 additions & 10 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -4,6 +4,8 @@
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse

from huggingface_hub import constants
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
@@ -34,7 +36,9 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
def __init__(self):
super().__init__("automatic-speech-recognition")

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
# If input is a URL, pass it directly
audio_url = inputs
@@ -61,14 +65,31 @@ class FalAITextToImageTask(FalAITask):
def __init__(self):
super().__init__("text-to-image")

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
parameters = filter_none(parameters)
if "width" in parameters and "height" in parameters:
parameters["image_size"] = {
"width": parameters.pop("width"),
"height": parameters.pop("height"),
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
payload: Dict[str, Any] = {
"prompt": inputs,
**filter_none(parameters),
}
if "width" in payload and "height" in payload:
payload["image_size"] = {
"width": payload.pop("width"),
"height": payload.pop("height"),
}
return {"prompt": inputs, **parameters}
if provider_mapping_info.adapter_weights_path is not None:
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=provider_mapping_info.hf_model_id,
revision="main",
filename=provider_mapping_info.adapter_weights_path,
)
payload["loras"] = [{"path": lora_path, "scale": 1}]
if provider_mapping_info.provider_id == "fal-ai/lora":
# little hack: fal requires the base model for stable-diffusion-based loras but not for flux-based
# See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api
payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0"

return payload

def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
url = _as_dict(response)["images"][0]["url"]
@@ -79,7 +100,9 @@ class FalAITextToSpeechTask(FalAITask):
def __init__(self):
super().__init__("text-to-speech")

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
return {"text": inputs, **filter_none(parameters)}

def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
@@ -104,7 +127,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
return f"/{mapped_model}?_subdomain=queue"
return f"/{mapped_model}"

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
return {"prompt": inputs, **filter_none(parameters)}

def get_response(
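For LoRA repos, `FalAITextToImageTask` now builds a `loras` entry pointing at the adapter weights file on the Hub (via `HUGGINGFACE_CO_URL_TEMPLATE`) and, for the stable-diffusion-based `fal-ai/lora` endpoint, also pins the base model. A sketch of the payload produced for a hypothetical SDXL LoRA mapping:

```python
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._providers.fal_ai import FalAITextToImageTask

# Hypothetical mapping for an SDXL LoRA repo on the Hub.
mapping = InferenceProviderMapping(
    hf_model_id="username/my-sdxl-lora",
    providerId="fal-ai/lora",
    task="text-to-image",
    status="live",
    adapter="lora",
    adapterWeightsPath="pytorch_lora_weights.safetensors",
)

payload = FalAITextToImageTask()._prepare_payload_as_dict(
    "a watercolor fox", {"width": 1024, "height": 1024}, mapping
)
# Expected shape, per the code above:
# {
#     "prompt": "a watercolor fox",
#     "image_size": {"width": 1024, "height": 1024},
#     "loras": [{"path": "https://huggingface.co/username/my-sdxl-lora/resolve/main/pytorch_lora_weights.safetensors",
#                "scale": 1}],
#     "model_name": "stabilityai/stable-diffusion-xl-base-1.0",  # only because provider_id == "fal-ai/lora"
# }
print(payload)
```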