@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():

    assert mock_engine.generate.call_args.args[1].max_tokens == 10

+    # Setting server's max_tokens in the generation_config.json
+    # lower than context_window - prompt_tokens
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "max_tokens": 10  # Setting server-side max_tokens limit
+    }
+
+    # Reinitialize the engine with new settings
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test Case 1: No max_tokens specified in request
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 2: Request's max_tokens set higher than server accepts
+    req.max_tokens = 15
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 3: Request's max_tokens set lower than server accepts
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
+
+    # Setting server's max_tokens in the generation_config.json
+    # higher than context_window - prompt_tokens
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "max_tokens": 200  # Setting server-side max_tokens limit
+    }
+
+    # Reinitialize the engine with new settings
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test Case 1: No max_tokens specified, defaults to context_window
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    # Test Case 2: Request's max_tokens set higher than server accepts
+    req.max_tokens = 100
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    # Test Case 3: Request's max_tokens set lower than server accepts
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
+


def test_serving_chat_could_load_correct_generation_config():
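For reference, the clamping rule these assertions exercise can be summarized in a few lines. The sketch below is a simplification, not vLLM's actual implementation, and the helper name effective_max_tokens is hypothetical; it assumes MockModelConfig uses a 100-token context window and that the prompt tokenizes to 7 tokens, which is consistent with the 93 asserted above.

from typing import Optional

# Minimal sketch (assumed, not vLLM's actual code): the effective
# max_tokens is the smallest of the request's value, the server-side
# default from generation_config.json, and the room left in the
# context window after the prompt.
def effective_max_tokens(request_max: Optional[int], server_max: Optional[int],
                         context_window: int, prompt_tokens: int) -> int:
    candidates = [context_window - prompt_tokens]
    if server_max is not None:
        candidates.append(server_max)
    if request_max is not None:
        candidates.append(request_max)
    return min(candidates)

# Reproduces the assertions above (context_window=100 and prompt_tokens=7
# are assumptions consistent with the asserted value of 93):
assert effective_max_tokens(None, 10, 100, 7) == 10   # low server limit wins
assert effective_max_tokens(15, 10, 100, 7) == 10     # request above server limit
assert effective_max_tokens(5, 10, 100, 7) == 5       # request below server limit
assert effective_max_tokens(None, 200, 100, 7) == 93  # window remainder wins
assert effective_max_tokens(100, 200, 100, 7) == 93
assert effective_max_tokens(5, 200, 100, 7) == 5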