12
12
from unittest import mock # python 3.3 and above
13
13
14
14
15
+ def mock_client_post (client , post_mock ):
16
+ # huggingface-hub==0.28.0 deprecates the `post` method
17
+ # so patch `_inner_post` instead
18
+ client .post = post_mock
19
+ client ._inner_post = post_mock
20
+
21
+
15
22
@pytest .mark .parametrize (
16
23
"send_default_pii, include_prompts, details_arg" ,
17
24
itertools .product ([True , False ], repeat = 3 ),
@@ -28,7 +35,7 @@ def test_nonstreaming_chat_completion(
28
35
29
36
client = InferenceClient ("some-model" )
30
37
if details_arg :
31
- client . post = mock .Mock (
38
+ post_mock = mock .Mock (
32
39
return_value = b"""[{
33
40
"generated_text": "the model response",
34
41
"details": {
@@ -40,9 +47,11 @@ def test_nonstreaming_chat_completion(
40
47
}]"""
41
48
)
42
49
else :
43
- client . post = mock .Mock (
50
+ post_mock = mock .Mock (
44
51
return_value = b'[{"generated_text": "the model response"}]'
45
52
)
53
+ mock_client_post (client , post_mock )
54
+
46
55
with start_transaction (name = "huggingface_hub tx" ):
47
56
response = client .text_generation (
48
57
prompt = "hello" ,
@@ -84,7 +93,8 @@ def test_streaming_chat_completion(
84
93
events = capture_events ()
85
94
86
95
client = InferenceClient ("some-model" )
87
- client .post = mock .Mock (
96
+
97
+ post_mock = mock .Mock (
88
98
return_value = [
89
99
b"""data:{
90
100
"token":{"id":1, "special": false, "text": "the model "}
@@ -95,6 +105,8 @@ def test_streaming_chat_completion(
95
105
}""" ,
96
106
]
97
107
)
108
+ mock_client_post (client , post_mock )
109
+
98
110
with start_transaction (name = "huggingface_hub tx" ):
99
111
response = list (
100
112
client .text_generation (
@@ -131,7 +143,9 @@ def test_bad_chat_completion(sentry_init, capture_events):
131
143
events = capture_events ()
132
144
133
145
client = InferenceClient ("some-model" )
134
- client .post = mock .Mock (side_effect = OverloadedError ("The server is overloaded" ))
146
+ post_mock = mock .Mock (side_effect = OverloadedError ("The server is overloaded" ))
147
+ mock_client_post (client , post_mock )
148
+
135
149
with pytest .raises (OverloadedError ):
136
150
client .text_generation (prompt = "hello" )
137
151
@@ -147,13 +161,15 @@ def test_span_origin(sentry_init, capture_events):
147
161
events = capture_events ()
148
162
149
163
client = InferenceClient ("some-model" )
150
- client . post = mock .Mock (
164
+ post_mock = mock .Mock (
151
165
return_value = [
152
166
b"""data:{
153
167
"token":{"id":1, "special": false, "text": "the model "}
154
168
}""" ,
155
169
]
156
170
)
171
+ mock_client_post (client , post_mock )
172
+
157
173
with start_transaction (name = "huggingface_hub tx" ):
158
174
list (
159
175
client .text_generation (
0 commit comments