@@ -229,13 +229,14 @@ def generate(
             generation_config.pad_token_id = eos_token_id
 
         if generation_config is not None and generation_config.max_new_tokens is not None:
-            max_new_tokens = generation_config.max_new_tokens
+            max_new_tokens = generation_config.pop("max_new_tokens")
         else:
-            max_new_tokens = kwargs.get("max_new_tokens", None)
+            max_new_tokens = kwargs.pop("max_new_tokens", None)
 
         return self.pipeline_parallel_generate(inputs=inputs,
                                                max_new_tokens=max_new_tokens,
-                                               generation_config=generation_config,)
+                                               generation_config=generation_config,
+                                               **kwargs)
 
     return original_generate(self,
                              inputs=inputs,
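Not part of the patch: a minimal sketch of the Python behaviour behind the switch from `get` to `pop` above. Since `**kwargs` is now forwarded to `pipeline_parallel_generate`, leaving `max_new_tokens` inside `kwargs` would deliver it twice (once explicitly, once via `**kwargs`); the values below are made up.

def pipeline_parallel_generate(inputs=None, max_new_tokens=32, **kwargs):
    return max_new_tokens

kwargs = {"max_new_tokens": 64, "num_beams": 1}        # hypothetical caller kwargs
max_new_tokens = kwargs.pop("max_new_tokens", None)    # pop removes the key from kwargs
pipeline_parallel_generate(max_new_tokens=max_new_tokens, **kwargs)   # fine

kwargs = {"max_new_tokens": 64, "num_beams": 1}
max_new_tokens = kwargs.get("max_new_tokens", None)    # get leaves the key in kwargs
# pipeline_parallel_generate(max_new_tokens=max_new_tokens, **kwargs)
# -> TypeError: got multiple values for keyword argument 'max_new_tokens'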
@@ -257,6 +258,23 @@ def pipeline_parallel_generate(self,
                                max_new_tokens: int = 32,
                                generation_config: Optional[GenerationConfig] = None,
                                **kwargs):
+    model_kwargs = generation_config.update(**kwargs)
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    bs = inputs_tensor.shape[0]
+    if self.config.is_encoder_decoder:
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=bs,
+            model_input_name=model_input_name,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            device=inputs_tensor.device,
+        )
+    else:
+        input_ids = inputs_tensor if model_input_name == "input_ids" \
+            else model_kwargs.pop("input_ids")
     local_rank = dist.get_rank()
     pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
     next_rank = (local_rank + 1) % self.pipeline_parallel_stages
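A separate sketch, not part of the patch, of what the new prologue computes: `GenerationConfig.update(**kwargs)` absorbs recognised generation options into the config and returns the leftovers, which become the `model_kwargs` forwarded to every forward call in the loop below. The concrete values are illustrative.

from transformers import GenerationConfig

gen_cfg = GenerationConfig()
# `do_sample` is a GenerationConfig field, so it is absorbed into the config;
# `attention_mask` is not, so it comes back and is later passed via **model_kwargs
model_kwargs = gen_cfg.update(do_sample=False, attention_mask=None)
print(model_kwargs)   # {'attention_mask': None}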
@@ -272,36 +290,44 @@ def pipeline_parallel_generate(self,
     eos_token_id = generation_config.eos_token_id
     if isinstance(eos_token_id, int):
         eos_token_id = [eos_token_id]
-    eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \
+    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \
         if eos_token_id is not None else None
 
     _input_ids = None
     _past_key_values = None
-    bs = inputs.shape[0]
-    output_ids = inputs.clone()
+
+    bs = input_ids.shape[0]
+    output_ids = input_ids.clone()
     _check_quantize_kv_cache(self, layer_start, bs)
 
     step = 0
     # keep track of which sequences are already finished
-    unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
+    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
     this_peer_finished = False
     while True:
         if step >= max_new_tokens:
             break
 
         if _input_ids is None:
-            _input_ids = inputs
+            _input_ids = input_ids
 
         tic = time.time()
         if local_rank == 0:
             outputs = self(input_ids=_input_ids, inputs_embeds=None,
-                           past_key_values=_past_key_values, use_cache=True)
+                           past_key_values=_past_key_values, use_cache=True, **model_kwargs)
         else:
-            inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
+            _inputs_shape = _input_ids.shape + (self.config.hidden_size,)
+            if step == 0 and self.config.model_type == "chatglm" \
+                    and hasattr(self.config, "vision_config"):
+                # for glm-4v, image features are mapped during 1st token
+                # 1597 are computed according to computation process of conv
+                _images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1]
+                _inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,)
+            inputs_embeds = torch.empty(_inputs_shape,
                                         device=f'xpu:{local_rank}', dtype=self.dtype)
             dist.recv(inputs_embeds, src=pre_rank)
             outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
-                           past_key_values=_past_key_values, use_cache=True)
+                           past_key_values=_past_key_values, use_cache=True, **model_kwargs)
 
         if local_rank == self.pipeline_parallel_stages - 1:
             logits = outputs.logits
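A sketch (not part of the patch) of the receive-buffer sizing added in the `else:` branch above: on the first step of a glm-4v run, later pipeline stages allocate an `inputs_embeds` buffer whose sequence dimension also covers the mapped image features. The constant 1597 and the formula come from the diff; the helper name below is hypothetical.

def _recv_embeds_shape(input_ids_shape, hidden_size, is_glm4v, step):
    # input_ids_shape: (batch_size, seq_len) of the tokens fed to stage 0
    bs, seq_len = input_ids_shape
    if step == 0 and is_glm4v:
        # image features are mapped into the hidden states during the 1st token;
        # 1597 follows from the vision conv computation (per the patch comment)
        images_feature = 1597 + bs * 2 + seq_len
        return (bs, images_feature, hidden_size)
    # later steps (and text-only models) keep the plain token shape
    return (bs, seq_len, hidden_size)

# e.g. one prompt of 32 tokens with hidden_size 4096:
# _recv_embeds_shape((1, 32), 4096, True, 0) -> (1, 1631, 4096)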
@@ -323,7 +349,8 @@ def pipeline_parallel_generate(self,
                                       "make sure that `pad_token_id` is defined.")
                 next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-        if self.config.model_type == "chatglm" and self.config.num_layers == 40:
+        if self.config.model_type == "chatglm" and self.config.num_layers == 40 \
+                and not hasattr(self.config, "vision_config"):
             # for glm-4-9b-chat
             if step == 0:
                 value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
@@ -337,7 +364,7 @@ def pipeline_parallel_generate(self,
                 _past_key_values = outputs.past_key_values
         elif self.config.model_type in ["baichuan", "chatglm"] or \
                 (self.config.model_type == "qwen" and hasattr(self.config, "visual")):
-            # for baichuan2, chatglm3, Qwen-VL-Chat
+            # for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b
             if local_rank != 0:
                 value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
                 past_key_values_placeholder = tuple(