8
8
9
9
class TestPagination (unittest .TestCase ):
10
10
@patch ("huggingface_hub.utils._pagination.get_session" )
11
+ @patch ("huggingface_hub.utils._pagination.http_backoff" )
11
12
@patch ("huggingface_hub.utils._pagination.hf_raise_for_status" )
12
13
@handle_injection_in_test
13
- def test_mocked_paginate (self , mock_get_session : Mock , mock_hf_raise_for_status : Mock ) -> None :
14
+ def test_mocked_paginate (
15
+ self , mock_get_session : Mock , mock_http_backoff : Mock , mock_hf_raise_for_status : Mock
16
+ ) -> None :
14
17
mock_get = mock_get_session ().get
15
18
mock_params = Mock ()
16
19
mock_headers = Mock ()
@@ -33,31 +36,32 @@ def test_mocked_paginate(self, mock_get_session: Mock, mock_hf_raise_for_status:
33
36
# Mock response
34
37
mock_get .side_effect = [
35
38
mock_response_page_1 ,
39
+ ]
40
+ mock_http_backoff .side_effect = [
36
41
mock_response_page_2 ,
37
42
mock_response_page_3 ,
38
43
]
39
44
40
45
results = paginate ("url" , params = mock_params , headers = mock_headers )
41
46
42
47
# Requests are made only when generator is yielded
43
- self . assertEqual ( mock_get .call_count , 0 )
48
+ assert mock_get .call_count == 0
44
49
45
50
# Results after concatenating pages
46
- self . assertListEqual ( list (results ), [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ])
51
+ assert list (results ) == [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ]
47
52
48
53
# All pages requested: 3 requests, 3 raise for status
49
- self .assertEqual (mock_get .call_count , 3 )
50
- self .assertEqual (mock_hf_raise_for_status .call_count , 3 )
54
+ # First request with `get_session.get` (we want at least 1 request to succeed correctly) and 2 with `http_backoff`
55
+ assert mock_get .call_count == 1
56
+ assert mock_http_backoff .call_count == 2
57
+ assert mock_hf_raise_for_status .call_count == 3
51
58
52
59
# Params not passed to next pages
53
- self .assertListEqual (
54
- mock_get .call_args_list ,
55
- [
56
- call ("url" , params = mock_params , headers = mock_headers ),
57
- call ("url_p2" , headers = mock_headers ),
58
- call ("url_p3" , headers = mock_headers ),
59
- ],
60
- )
60
+ assert mock_get .call_args_list == [call ("url" , params = mock_params , headers = mock_headers )]
61
+ assert mock_http_backoff .call_args_list == [
62
+ call ("GET" , "url_p2" , max_retries = 20 , retry_on_status_codes = 429 , headers = mock_headers ),
63
+ call ("GET" , "url_p3" , max_retries = 20 , retry_on_status_codes = 429 , headers = mock_headers ),
64
+ ]
61
65
62
66
def test_paginate_github_api (self ) -> None :
63
67
# Real test: paginate over huggingface repos on Github
0 commit comments