Commit da58255

refactor inference (#1245)
1 parent 8728264 commit da58255

File tree: 6 files changed (+152 -187 lines)

swift/llm/infer.py (+1 -1)
@@ -243,7 +243,7 @@ def prepare_model_template(args: InferArguments,

 def read_media_file(infer_kwargs: Dict[str, Any], infer_media_type: Literal['none', 'round', 'dialogue']) -> None:
     text = 'Input a media path or URL <<< '
-    images = infer_kwargs.get('images', [])
+    images = infer_kwargs.get('images') or []
     if infer_media_type == 'none':
         return
     if infer_media_type == 'round' or len(images) == 0:
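
A note on this one-line change (the same pattern is applied in argument.py and preprocess.py below): `dict.get(key, default)` only falls back to the default when the key is missing, so a key that is present with an explicit `None` value still comes back as `None` and later breaks `len(...)` or membership checks. `dict.get(key) or []` also normalizes that case, at the cost of mapping any other falsy value (e.g. an empty string) to `[]`. A standalone illustration, not part of the commit:

    infer_kwargs = {'images': None}                # key present, value explicitly None

    old_style = infer_kwargs.get('images', [])     # -> None: the default is not used
    # len(old_style) would raise TypeError: object of type 'NoneType' has no len()

    new_style = infer_kwargs.get('images') or []   # -> []: None is normalized away
    assert new_style == [] and len(new_style) == 0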

swift/llm/utils/argument.py (+1 -1)
@@ -1299,7 +1299,7 @@ class DeployArguments(InferArguments):
     def __post_init__(self):
         super().__post_init__()
         model_info = MODEL_MAPPING[self.model_type]
-        tags = model_info.get('tags', [])
+        tags = model_info.get('tags') or []
         self.is_multimodal = 'multi-modal' in tags

swift/llm/utils/preprocess.py (+1 -1)
@@ -197,7 +197,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
         response = conversations[-1][self.value_key]
         system = sys
         history = h
-        tools = d.get('tools', [])
+        tools = d.get('tools') or []
         row = {'system': system, 'history': history, 'history_roles': hr}
         row.update({
             'query': query,

swift/llm/utils/utils.py (+94 -111)
@@ -542,38 +542,22 @@ def __next__(self) -> List[int]:
         return value


-@torch.inference_mode()
-def inference_stream(model: PreTrainedModel,
-                     template: Template,
-                     query: str,
-                     history: Optional[History] = None,
-                     system: Optional[str] = None,
-                     images: Optional[List[str]] = None,
-                     *,
-                     generation_config: Optional[GenerationConfig] = None,
-                     stop_words: Optional[StopWords] = None,
-                     generation_info: Optional[Dict[str, int]] = None,
-                     adapter_names: Optional[List[str]] = None,
-                     **kwargs) -> Iterator[Tuple[str, History]]:
-    """
-    generation_config: Priority: generation_config > model.generation_config.
-    """
+def _prepare_inputs(model: PreTrainedModel,
+                    template: Template,
+                    query: str,
+                    history: History,
+                    system: Optional[str] = None,
+                    images: Optional[List[str]] = None,
+                    *,
+                    generation_config: Optional[GenerationConfig] = None,
+                    stop_words: Optional[StopWords] = None,
+                    adapter_names: Optional[List[str]] = None,
+                    **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], int]:
     if stop_words is None:
         stop_words = []
-    if history is None:
-        history = []
-    else:
-        history = deepcopy(history)
     if images is None:
         images = []

-    # agent support
-    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
-    if is_observation:
-        history[-1][-1] = history[-1][-1] + query
-        act_length = len(history[-1][-1])
-        query = None
-
     example = {
         'query': query,
         'history': history,
@@ -587,7 +571,7 @@ def inference_stream(model: PreTrainedModel,
     truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
     if len(inputs) == 0 and truncation_strategy == 'delete':
         # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
+        return {}, tokenizer_kwargs, 0

     inputs.pop('labels', None)
     tokenizer = template.tokenizer
@@ -606,11 +590,8 @@ def inference_stream(model: PreTrainedModel,
         inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
     model.eval()
     if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
+        generation_config = getattr(model, 'generation_config')
     generation_config = deepcopy(generation_config)
-    if generation_config.num_beams != 1:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)

     if tokenizer.eos_token_id is not None:
         generation_config.eos_token_id = tokenizer.eos_token_id
@@ -627,21 +608,69 @@ def inference_stream(model: PreTrainedModel,
                 raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
     if template.suffix[-1] not in stop_words:
         stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
     inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
     if 'inputs_embeds' in inputs:
         inputs.pop('input_ids', None)
-    streamer = TokenListIteratorStreamer()
     if adapter_names is not None:
         inputs['adapter_names'] = adapter_names
-    generation_kwargs = {
-        'streamer': streamer,
-        'generation_config': generation_config,
-        'stopping_criteria': stopping_criteria,
-        **inputs
-    }
+
+    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
+    inputs['stopping_criteria'] = stopping_criteria
+    inputs['generation_config'] = generation_config
+    return inputs, tokenizer_kwargs, token_len
+
+
+@torch.inference_mode()
+def inference_stream(model: PreTrainedModel,
+                     template: Template,
+                     query: str,
+                     history: Optional[History] = None,
+                     system: Optional[str] = None,
+                     images: Optional[List[str]] = None,
+                     *,
+                     generation_config: Optional[GenerationConfig] = None,
+                     stop_words: Optional[StopWords] = None,
+                     generation_info: Optional[Dict[str, int]] = None,
+                     adapter_names: Optional[List[str]] = None,
+                     **kwargs) -> Iterator[Tuple[str, History]]:
+    """
+    generation_config: Priority: generation_config > model.generation_config.
+    """
+    if history is None:
+        history = []
+    else:
+        history = deepcopy(history)
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len
+
+    # agent support
+    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
+    if is_observation:
+        history[-1][-1] = history[-1][-1] + query
+        act_length = len(history[-1][-1])
+        query = None
+
+    generation_config = inputs['generation_config']
+    if generation_config.num_beams != 1:
+        error_msg = 'Streaming generation does not support beam search.'
+        raise ValueError(error_msg)
+
+    streamer = TokenListIteratorStreamer()
+    generation_kwargs = {'streamer': streamer, **inputs}
     _model_generate = model.generate
     if is_torch_npu_available():

@@ -667,8 +696,7 @@ def _model_generate(*args, **kwargs):
         except StopIteration:
             is_finished = True
         generate_ids = template.get_generate_ids(torch.tensor(raw_generate_ids)[None], token_len)
-        if generation_info is not None:
-            generation_info['num_generated_tokens'] = len(generate_ids)
+        generation_info['num_generated_tokens'] = len(generate_ids)
         response = template.generate_ids_to_response(
             generate_ids,
             is_finished,
@@ -702,58 +730,38 @@ def inference(model: PreTrainedModel,
     """
     generation_config: Priority: generation_config > model.generation_config.
     """
-    if stop_words is None:
-        stop_words = []
     if history is None:
         history = []
     else:
         history = deepcopy(history)
-    if images is None:
-        images = []
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len

+    # agent support
     is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
     if is_observation:
         history[-1][-1] = history[-1][-1] + query
         query = None

-    example = {
-        'query': query,
-        'history': history,
-        'system': system,
-        'images': images,  # for vl. str.
-        'tools': kwargs.pop('tools', None)
-    }
-    template.model = model
-    inputs, tokenizer_kwargs = template.encode(example)
-
-    truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
-    if len(inputs) == 0 and truncation_strategy == 'delete':
-        # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
-
-    inputs.pop('labels', None)
-    tokenizer = template.tokenizer
-    device = next(model.parameters()).device
-    if 'input_ids' in inputs:
-        input_ids = torch.tensor(inputs['input_ids'])[None]
-        inputs['input_ids'] = input_ids
-        token_len = input_ids.shape[1]
-    if 'inputs_embeds' in inputs:
-        inputs_embeds = inputs['inputs_embeds'][None]
-        inputs['inputs_embeds'] = inputs_embeds
-        token_len = inputs_embeds.shape[1]
-
-    inputs['attention_mask'] = torch.ones(token_len)[None]
-    if 'token_type_ids' in inputs:
-        inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
-    model.eval()
-    if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
-    generation_config = deepcopy(generation_config)
     if stream and not verbose:
         logger.warning('Please set verbose to True to support TextStreamer, or use `inference_stream.`')
         stream = False
     streamer = None
+    tokenizer = template.tokenizer
     if stream:
         streamer = TextStreamer(tokenizer, skip_prompt=True)
     if verbose:
@@ -762,37 +770,12 @@ def inference(model: PreTrainedModel,
             print(
                 f'{prompt_prefix}{safe_tokenizer_decode(tokenizer, input_ids[0], **tokenizer_kwargs)}{output_prefix}',
                 end='')
-        elif 'query' in example:
-            query = example['query']
+        else:
             print(f'[QUERY]{query}\n{output_prefix}', end='')
-    if tokenizer.eos_token_id is not None:
-        generation_config.eos_token_id = tokenizer.eos_token_id
-    if tokenizer.pad_token_id is not None:
-        generation_config.pad_token_id = tokenizer.pad_token_id
-    if tokenizer.bos_token_id is not None:
-        generation_config.bos_token_id = tokenizer.bos_token_id
-    if generation_config.max_new_tokens is not None:
-        generation_config.max_length = 20  # fix max_length, max_new_tokens warning
-        max_length = get_max_model_len(model.config)
-        if max_length and token_len + generation_config.max_new_tokens > max_length:
-            generation_config.max_new_tokens = max_length - token_len
-            if generation_config.max_new_tokens <= 0:
-                raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
-    if template.suffix[-1] not in stop_words:
-        stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
-    inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
-    if 'inputs_embeds' in inputs:
-        inputs.pop('input_ids', None)
-    if adapter_names is not None:
-        inputs['adapter_names'] = adapter_names
-    generate_ids = model.generate(
-        streamer=streamer, generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs)
+
+    generate_ids = model.generate(streamer=streamer, **inputs)
     generate_ids = template.get_generate_ids(generate_ids, token_len)
-    if generation_info is not None:
-        generation_info['num_generated_tokens'] = len(generate_ids)
+    generation_info['num_generated_tokens'] = len(generate_ids)
     if verbose and stream is False:
         response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
         print(response)

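Taken together, the utils.py changes extract the shared prompt encoding and generation-config setup into `_prepare_inputs`, which returns `(inputs, tokenizer_kwargs, token_len)` with `generation_config` and `stopping_criteria` already packed into `inputs`, so the final generate call simply spreads `inputs` (e.g. `model.generate(streamer=streamer, **inputs)` in `inference`). On the truncation path `_prepare_inputs` returns an empty dict and the callers still return `'', history`; `generation_info` is now always populated (a fresh dict is created when the caller passes None). A minimal usage sketch of the refactored streaming call, assuming the usual `swift.llm` helpers and a placeholder model type (not part of the commit):

    from swift.llm import (ModelType, get_default_template_type, get_model_tokenizer,
                           get_template, inference_stream)

    model_type = ModelType.qwen_7b_chat  # placeholder; any supported model_type works
    template_type = get_default_template_type(model_type)
    model, tokenizer = get_model_tokenizer(model_type, model_kwargs={'device_map': 'auto'})
    template = get_template(template_type, tokenizer)

    generation_info = {}  # filled in by the refactored code path
    response, history = '', None
    for response, history in inference_stream(
            model, template, 'Who are you?', history, generation_info=generation_info):
        pass  # each iteration yields the partial response decoded so far

    print(response)
    print(generation_info)  # {'num_prompt_tokens': ..., 'num_generated_tokens': ...}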