1
- from __future__ import annotations
2
-
1
+ import functools
2
+ import json
3
3
import sys
4
4
import time
5
- import json
6
5
import traceback
7
- import functools
6
+ from pathlib import Path
7
+
8
8
import bottle
9
- from bottle import Bottle , HTTPResponse , request
9
+ from bottle import Bottle , HTTPResponse
10
10
11
11
imds = Bottle (autojson = True )
12
12
"""An Azure IMDS server"""
13
13
14
- from typing import TYPE_CHECKING , Any , Callable , Iterable , overload
14
+ from typing import TYPE_CHECKING , Any , Callable , Iterable , cast , overload
15
15
16
- if TYPE_CHECKING :
16
+ if not TYPE_CHECKING :
17
+ from bottle import request
18
+ else :
17
19
from typing import Protocol
18
20
19
21
class _RequestParams (Protocol ):
@@ -22,7 +24,7 @@ def __getitem__(self, key: str) -> str:
22
24
...
23
25
24
26
@overload
25
- def get (self , key : str ) -> str | None :
27
+ def get (self , key : str ) -> ' str | None' :
26
28
...
27
29
28
30
@overload
@@ -31,25 +33,30 @@ def get(self, key: str, default: str) -> str:
31
33
32
34
class _HeadersDict (dict [str , str ]):
33
35
34
- def raw (self , key : str ) -> bytes | None :
36
+ def raw (self , key : str ) -> ' bytes | None' :
35
37
...
36
38
37
39
class _Request (Protocol ):
38
- query : _RequestParams
39
- params : _RequestParams
40
- headers : _HeadersDict
41
40
42
- request : _Request
41
+ @property
42
+ def query (self ) -> _RequestParams :
43
+ ...
43
44
45
+ @property
46
+ def params (self ) -> _RequestParams :
47
+ ...
44
48
45
- def parse_qs (qs : str ) -> dict [str , str ]:
46
- return dict (bottle ._parse_qsl (qs )) # type: ignore
49
+ @property
50
+ def headers (self ) -> _HeadersDict :
51
+ ...
47
52
53
+ request = cast ('_Request' , None )
48
54
49
- def require (cond : bool , message : str ):
50
- if not cond :
51
- print (f'REQUIREMENT FAILED: { message } ' )
52
- raise bottle .HTTPError (400 , message )
55
+
56
+ def parse_qs (qs : str ) -> 'dict[str, str]' :
57
+ # Re-use the bottle.py query string parser. It's a private function, but
58
+ # we're using a fixed version of Bottle.
59
+ return dict (bottle ._parse_qsl (qs )) # type: ignore
53
60
54
61
55
62
_HandlerFuncT = Callable [
@@ -58,6 +65,7 @@ def require(cond: bool, message: str):
58
65
59
66
60
67
def handle_asserts (fn : _HandlerFuncT ) -> _HandlerFuncT :
68
+ "Convert assertion failures into HTTP 400s"
61
69
62
70
@functools .wraps (fn )
63
71
def wrapped ():
@@ -72,17 +80,10 @@ def wrapped():
72
80
return wrapped
73
81
74
82
75
- def test_flags () -> dict [str , str ]:
83
+ def test_params () -> ' dict[str, str]' :
76
84
return parse_qs (request .headers .get ('X-MongoDB-HTTP-TestParams' , '' ))
77
85
78
86
79
- def maybe_pause ():
80
- pause = int (test_flags ().get ('pause' , '0' ))
81
- if pause :
82
- print (f'Pausing for { pause } seconds' )
83
- time .sleep (pause )
84
-
85
-
86
87
@imds .get ('/metadata/identity/oauth2/token' )
87
88
@handle_asserts
88
89
def get_oauth2_token ():
@@ -91,10 +92,7 @@ def get_oauth2_token():
91
92
resource = request .query ['resource' ]
92
93
assert resource == 'https://vault.azure.net' , 'Only https://vault.azure.net is supported'
93
94
94
- flags = test_flags ()
95
- maybe_pause ()
96
-
97
- case = flags .get ('case' )
95
+ case = test_params ().get ('case' )
98
96
print ('Case is:' , case )
99
97
if case == '404' :
100
98
return HTTPResponse (status = 404 )
@@ -114,17 +112,18 @@ def get_oauth2_token():
114
112
if case == 'slow' :
115
113
return _slow ()
116
114
117
- assert case is None or case == '' , f 'Unknown HTTP test case "{ case } "'
115
+ assert case in ( None , '' ), 'Unknown HTTP test case "{}"' . format ( case )
118
116
119
117
return {
120
118
'access_token' : 'magic-cookie' ,
121
- 'expires_in' : '60 ' ,
119
+ 'expires_in' : '70 ' ,
122
120
'token_type' : 'Bearer' ,
123
121
'resource' : 'https://vault.azure.net' ,
124
122
}
125
123
126
124
127
125
def _gen_giant () -> Iterable [bytes ]:
126
+ "Generate a giant message"
128
127
yield b'{ "item": ['
129
128
for _ in range (1024 * 256 ):
130
129
yield (b'null, null, null, null, null, null, null, null, null, null, '
@@ -136,6 +135,7 @@ def _gen_giant() -> Iterable[bytes]:
136
135
137
136
138
137
def _slow () -> Iterable [bytes ]:
138
+ "Generate a very slow message"
139
139
yield b'{ "item": ['
140
140
for _ in range (1000 ):
141
141
yield b'null, '
@@ -144,6 +144,8 @@ def _slow() -> Iterable[bytes]:
144
144
145
145
146
146
if __name__ == '__main__' :
147
- print (f'RECOMMENDED: Run this script using bottle.py in the same '
148
- f'directory (e.g. [{ sys .executable } bottle.py fake_azure:imds])' )
147
+ print (
148
+ 'RECOMMENDED: Run this script using bottle.py (e.g. [{} {}/bottle.py fake_azure:imds])'
149
+ .format (sys .executable ,
150
+ Path (__file__ ).resolve ().parent ))
149
151
imds .run ()
0 commit comments