
Commit b19ab11

(misc) Deprecate some hf-inference specific features (wait-for-model header, can't override model's task, get_model_status, list_deployed_models) (#2851)
* (draft) deprecate some hf-inference specific features
* remove hf-inference specific behavior (wait for model + handle 503)
* remove "make sure" sentence
* add back sentence-similarity task but use /models instead of /pipeline/tag
* async as well
* fix cassettes
1 parent bf80d1c commit b19ab11

9 files changed: +44 / -80 lines changed
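For users migrating off the deprecated helpers, the deprecation messages added in this commit point to `HfApi` as the replacement. The following is a minimal sketch of that migration path; it assumes a huggingface_hub release in which `list_models` accepts an `inference_provider` filter and `ModelInfo` exposes an `inference` field (both are named in the deprecation messages below but not verified here):

    from huggingface_hub import HfApi

    api = HfApi()

    # Instead of InferenceClient.list_deployed_models(): list warm models for a provider.
    for model in api.list_models(inference_provider="hf-inference", limit=5):
        print(model.id)

    # Instead of InferenceClient.get_model_status(): inspect the model's inference status.
    info = api.model_info("sentence-transformers/all-MiniLM-L6-v2")
    print(info.inference)  # assumed field; e.g. "warm" when the model is deployed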

docs/source/de/guides/inference.md

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ Das Ziel von [`InferenceClient`] ist es, die einfachste Schnittstelle zum Ausfü
 | | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) || [`~InferenceClient.feature_extraction`] |
 | | [Fill Mask](https://huggingface.co/tasks/fill-mask) || [`~InferenceClient.fill_mask`] |
 | | [Question Answering](https://huggingface.co/tasks/question-answering) || [`~InferenceClient.question_answering`] |
-| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) || [`~InferenceClient.sentence_similarity`] |
+| | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) || [`~InferenceClient.sentence_similarity`] |
 | | [Summarization](https://huggingface.co/tasks/summarization) || [`~InferenceClient.summarization`] |
 | | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) || [`~InferenceClient.table_question_answering`] |
 | | [Text Classification](https://huggingface.co/tasks/text-classification) || [`~InferenceClient.text_classification`] |

src/huggingface_hub/inference/_client.py

Lines changed: 14 additions & 24 deletions
@@ -35,7 +35,6 @@
 import base64
 import logging
 import re
-import time
 import warnings
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload

@@ -301,8 +300,6 @@ def _inner_post(
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        t0 = time.time()
-        timeout = self.timeout
         while True:
             with _open_as_binary(request_parameters.data) as data_as_binary:
                 try:

@@ -326,30 +323,9 @@ def _inner_post(
             except HTTPError as error:
                 if error.response.status_code == 422 and request_parameters.task != "unknown":
                     msg = str(error.args[0])
-                    print(error.response.text)
                     if len(error.response.text) > 0:
                         msg += f"\n{error.response.text}\n"
-                    msg += f"\nMake sure '{request_parameters.task}' task is supported by the model."
                     error.args = (msg,) + error.args[1:]
-                if error.response.status_code == 503:
-                    # If Model is unavailable, either raise a TimeoutError...
-                    if timeout is not None and time.time() - t0 > timeout:
-                        raise InferenceTimeoutError(
-                            f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout (current:"
-                            f" {self.timeout}).",
-                            request=error.request,
-                            response=error.response,
-                        ) from error
-                    # ...or wait 1s and retry
-                    logger.info(f"Waiting for model to be loaded on the server: {error}")
-                    time.sleep(1)
-                    if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
-                        INFERENCE_ENDPOINT
-                    ):
-                        request_parameters.headers["X-wait-for-model"] = "1"
-                    if timeout is not None:
-                        timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
-                    continue
                 raise

     def audio_classification(

@@ -3261,6 +3237,13 @@ def zero_shot_image_classification(
         response = self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

+    @_deprecate_method(
+        version="0.33.0",
+        message=(
+            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
+            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
+        ),
+    )
     def list_deployed_models(
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
     ) -> Dict[str, List[str]]:

@@ -3444,6 +3427,13 @@ def health_check(self, model: Optional[str] = None) -> bool:
         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         return response.status_code == 200

+    @_deprecate_method(
+        version="0.33.0",
+        message=(
+            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
+            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
+        ),
+    )
     def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
         """
         Get the status of a model hosted on the HF Inference API.
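With the 503 handling removed, `InferenceClient` no longer waits for a cold model to load; callers who still want that behavior have to retry on their side. A rough, hypothetical sketch of such a wrapper follows; it assumes the server keeps answering HTTP 503 while a model loads and that the client surfaces this as a requests `HTTPError` (names such as `call_with_wait` and the example model are illustrative only):

    import time

    from requests import HTTPError

    from huggingface_hub import InferenceClient


    def call_with_wait(fn, *args, total_timeout: float = 120.0, **kwargs):
        # Retry an InferenceClient call while the server answers 503 (model still loading).
        start = time.time()
        while True:
            try:
                return fn(*args, **kwargs)
            except HTTPError as error:
                status = getattr(error.response, "status_code", None)
                if status != 503 or time.time() - start > total_timeout:
                    raise  # not a "model loading" error, or we waited long enough
                time.sleep(1)  # model not ready yet: wait a second and retry


    client = InferenceClient()
    result = call_with_wait(client.text_classification, "I love this!", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")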

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 14 additions & 24 deletions
@@ -22,7 +22,6 @@
 import base64
 import logging
 import re
-import time
 import warnings
 from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload

@@ -299,8 +298,6 @@ async def _inner_post(
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        t0 = time.time()
-        timeout = self.timeout
         while True:
             with _open_as_binary(request_parameters.data) as data_as_binary:
                 # Do not use context manager as we don't want to close the connection immediately when returning

@@ -331,27 +328,6 @@ async def _inner_post(
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
                 await session.close()
-                if response.status == 422 and request_parameters.task != "unknown":
-                    error.message += f". Make sure '{request_parameters.task}' task is supported by the model."
-                if response.status == 503:
-                    # If Model is unavailable, either raise a TimeoutError...
-                    if timeout is not None and time.time() - t0 > timeout:
-                        raise InferenceTimeoutError(
-                            f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout"
-                            f" (current: {self.timeout}).",
-                            request=error.request,
-                            response=error.response,
-                        ) from error
-                    # ...or wait 1s and retry
-                    logger.info(f"Waiting for model to be loaded on the server: {error}")
-                    if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(
-                        INFERENCE_ENDPOINT
-                    ):
-                        request_parameters.headers["X-wait-for-model"] = "1"
-                    await asyncio.sleep(1)
-                    if timeout is not None:
-                        timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
-                    continue
                 raise error
             except Exception:
                 await session.close()

@@ -3325,6 +3301,13 @@ async def zero_shot_image_classification(
         response = await self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

+    @_deprecate_method(
+        version="0.33.0",
+        message=(
+            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
+            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
+        ),
+    )
     async def list_deployed_models(
         self, frameworks: Union[None, str, Literal["all"], List[str]] = None
     ) -> Dict[str, List[str]]:

@@ -3554,6 +3537,13 @@ async def health_check(self, model: Optional[str] = None) -> bool:
         response = await client.get(url, proxy=self.proxies)
         return response.status == 200

+    @_deprecate_method(
+        version="0.33.0",
+        message=(
+            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
+            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
+        ),
+    )
     async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
         """
         Get the status of a model hosted on the HF Inference API.

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 1 addition & 8 deletions
@@ -38,14 +38,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
         # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
         if mapped_model.startswith(("http://", "https://")):
             return mapped_model
-
-        return (
-            # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
-            f"{self.base_url}/pipeline/{self.task}/{mapped_model}"
-            if self.task in ("feature-extraction", "sentence-similarity")
-            # Otherwise, we use the default endpoint
-            else f"{self.base_url}/models/{mapped_model}"
-        )
+        return f"{self.base_url}/models/{mapped_model}"

     def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         if isinstance(inputs, bytes):
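The provider change above means feature-extraction and sentence-similarity requests no longer go through the `/pipeline/<task>/` route; every task is now served from `/models/`. A small illustration of the resulting URL, using the model already exercised in the updated cassettes below (the base URL value is taken from those cassettes and shown here only for illustration):

    base_url = "https://router.huggingface.co/hf-inference"
    mapped_model = "sentence-transformers/all-MiniLM-L6-v2"

    # Before: f"{base_url}/pipeline/sentence-similarity/{mapped_model}"
    # After: every task uses the same route.
    print(f"{base_url}/models/{mapped_model}")
    # -> https://router.huggingface.co/hf-inference/models/sentence-transformers/all-MiniLM-L6-v2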

tests/cassettes/TestInferenceClient.test_sentence_similarity[hf-inference,sentence-similarity].yaml

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ interactions:
      X-Amzn-Trace-Id:
      - 0434ff33-56fe-49db-9380-17b81e41f756
    method: POST
-    uri: https://router.huggingface.co/hf-inference/pipeline/sentence-similarity/sentence-transformers/all-MiniLM-L6-v2
+    uri: https://router.huggingface.co/hf-inference/models/sentence-transformers/all-MiniLM-L6-v2
  response:
    body:
      string: '[0.7785724997520447,0.4587624967098236,0.29062220454216003]'

tests/cassettes/test_async_sentence_similarity.yaml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ interactions:
    body: null
    headers: {}
    method: POST
-    uri: https://router.huggingface.co/hf-inference/pipeline/sentence-similarity/sentence-transformers/all-MiniLM-L6-v2
+    uri: https://router.huggingface.co/hf-inference/models/sentence-transformers/all-MiniLM-L6-v2
  response:
    body:
      string: '[0.7785724997520447,0.4587624967098236,0.29062220454216003]'

tests/test_inference_async_client.py

Lines changed: 5 additions & 0 deletions
@@ -300,6 +300,7 @@ def test_sync_vs_async_signatures() -> None:


 @pytest.mark.asyncio
+@pytest.mark.skip("Deprecated (get_model_status)")
 async def test_get_status_too_big_model() -> None:
     model_status = await AsyncInferenceClient(token=False).get_model_status("facebook/nllb-moe-54b")
     assert model_status.loaded is False

@@ -309,6 +310,7 @@ async def test_get_status_too_big_model() -> None:


 @pytest.mark.asyncio
+@pytest.mark.skip("Deprecated (get_model_status)")
 async def test_get_status_loaded_model() -> None:
     model_status = await AsyncInferenceClient(token=False).get_model_status("bigscience/bloom")
     assert model_status.loaded is True

@@ -318,18 +320,21 @@ async def test_get_status_loaded_model() -> None:


 @pytest.mark.asyncio
+@pytest.mark.skip("Deprecated (get_model_status)")
 async def test_get_status_unknown_model() -> None:
     with pytest.raises(ClientResponseError):
         await AsyncInferenceClient(token=False).get_model_status("unknown/model")


 @pytest.mark.asyncio
+@pytest.mark.skip("Deprecated (get_model_status)")
 async def test_get_status_model_as_url() -> None:
     with pytest.raises(NotImplementedError):
         await AsyncInferenceClient(token=False).get_model_status("https://unkown/model")


 @pytest.mark.asyncio
+@pytest.mark.skip("Deprecated (list_deployed_models)")
 async def test_list_deployed_models_single_frameworks() -> None:
     models_by_task = await AsyncInferenceClient().list_deployed_models("text-generation-inference")
     assert isinstance(models_by_task, dict)

tests/test_inference_client.py

Lines changed: 7 additions & 0 deletions
@@ -869,6 +869,7 @@ def test_accept_header_image(self, get_session_mock: MagicMock, bytes_to_image_m


 class TestModelStatus(TestBase):
+    @expect_deprecation("get_model_status")
     def test_too_big_model(self) -> None:
         client = InferenceClient(token=False)
         model_status = client.get_model_status("facebook/nllb-moe-54b")

@@ -877,6 +878,7 @@ def test_too_big_model(self) -> None:
         assert model_status.compute_type == "cpu"
         assert model_status.framework == "transformers"

+    @expect_deprecation("get_model_status")
     def test_loaded_model(self) -> None:
         client = InferenceClient(token=False)
         model_status = client.get_model_status("bigscience/bloom")

@@ -885,28 +887,33 @@ def test_loaded_model(self) -> None:
         assert isinstance(model_status.compute_type, dict)  # e.g. {'gpu': {'gpu': 'a100', 'count': 8}}
         assert model_status.framework == "text-generation-inference"

+    @expect_deprecation("get_model_status")
     def test_unknown_model(self) -> None:
         client = InferenceClient()
         with pytest.raises(HfHubHTTPError):
             client.get_model_status("unknown/model")

+    @expect_deprecation("get_model_status")
     def test_model_as_url(self) -> None:
         client = InferenceClient()
         with pytest.raises(NotImplementedError):
             client.get_model_status("https://unkown/model")


 class TestListDeployedModels(TestBase):
+    @expect_deprecation("list_deployed_models")
     @patch("huggingface_hub.inference._client.get_session")
     def test_list_deployed_models_main_frameworks_mock(self, get_session_mock: MagicMock) -> None:
         InferenceClient().list_deployed_models()
         assert len(get_session_mock.return_value.get.call_args_list) == len(MAIN_INFERENCE_API_FRAMEWORKS)

+    @expect_deprecation("list_deployed_models")
     @patch("huggingface_hub.inference._client.get_session")
     def test_list_deployed_models_all_frameworks_mock(self, get_session_mock: MagicMock) -> None:
         InferenceClient().list_deployed_models("all")
         assert len(get_session_mock.return_value.get.call_args_list) == len(ALL_INFERENCE_API_FRAMEWORKS)

+    @expect_deprecation("list_deployed_models")
     def test_list_deployed_models_single_frameworks(self) -> None:
         models_by_task = InferenceClient().list_deployed_models("text-generation-inference")
         assert isinstance(models_by_task, dict)

utils/generate_async_inference_client.py

Lines changed: 0 additions & 21 deletions
@@ -175,8 +175,6 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        t0 = time.time()
-        timeout = self.timeout
         while True:
             with _open_as_binary(request_parameters.data) as data_as_binary:
                 # Do not use context manager as we don't want to close the connection immediately when returning

@@ -205,25 +203,6 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
                 await session.close()
-                if response.status == 422 and request_parameters.task != "unknown":
-                    error.message += f". Make sure '{request_parameters.task}' task is supported by the model."
-                if response.status == 503:
-                    # If Model is unavailable, either raise a TimeoutError...
-                    if timeout is not None and time.time() - t0 > timeout:
-                        raise InferenceTimeoutError(
-                            f"Model not loaded on the server: {request_parameters.url}. Please retry with a higher timeout"
-                            f" (current: {self.timeout}).",
-                            request=error.request,
-                            response=error.response,
-                        ) from error
-                    # ...or wait 1s and retry
-                    logger.info(f"Waiting for model to be loaded on the server: {error}")
-                    if "X-wait-for-model" not in request_parameters.headers and request_parameters.url.startswith(INFERENCE_ENDPOINT):
-                        request_parameters.headers["X-wait-for-model"] = "1"
-                    await asyncio.sleep(1)
-                    if timeout is not None:
-                        timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
-                    continue
                 raise error
             except Exception:
                 await session.close()
