Commit 0d8708e

Fix qlora deploy (#1224)
1 parent 414b308 commit 0d8708e

12 files changed (+96, -62 lines)

README.md

+1 -1

@@ -502,7 +502,7 @@ The complete list of supported models and datasets can be found at [Supported Mo
 |------------------------------------------------|------------------------------------------------------------------------|--------------------|----------------------------------------|------------------------------------------- |
 | Qwen<br>Qwen1.5<br>Qwen2 | [Tongyi Qwen 1.0 and 1.5 series models](https://github.com/QwenLM) | Chinese<br>English | 0.5B-110B<br>including quantized versions | base model<br>chat model<br>MoE model<br>code model |
 | ChatGLM2<br>ChatGLM3<br>Codegeex2<br>GLM4 | [Zhipu ChatGLM series models](https://github.com/THUDM) | Chinese<br>English | 6B-9B | base model<br>chat model<br>code model<br>long text model |
-| Baichuan/Baichuan2 | [Baichuan 1 and Baichuan 2](https://github.com/baichuan-inc) | Chinese<br>English | 7B-13B<br>including quantized versions | base model<br>chat model |
+| Baichuan<br>Baichuan2 | [Baichuan 1 and Baichuan 2](https://github.com/baichuan-inc) | Chinese<br>English | 7B-13B<br>including quantized versions | base model<br>chat model |
 | Yuan2 | [Langchao Yuan series models](https://github.com/IEIT-Yuan) | Chinese<br>English | 2B-102B | instruct model |
 | XVerse | [XVerse series models](https://github.com/xverse-ai) | Chinese<br>English | 7B-65B | base model<br>chat model<br>long text model<br>MoE model |
 | LLaMA2 | [LLaMA2 series models](https://github.com/facebookresearch/llama) | English | 7B-70B<br>including quantized versions | base model<br>chat model |

swift/llm/deploy.py

+12 -11

@@ -157,7 +157,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
             kwargs[key] = new_value

     generation_config = VllmGenerationConfig(**kwargs)
-    if generation_config.use_beam_search is True and request.stream is True:
+    if generation_config.use_beam_search and request.stream:
         error_msg = 'Streaming generation does not support beam search.'
         raise ValueError(error_msg)
     tokenizer = template.tokenizer
@@ -391,16 +391,17 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq

     created_time = int(time.time())
     adapter_kwargs = {}
-    if request.model != _args.model_type:
-        adapter_names = None
-        for lora_req in _args.lora_request_list:
-            if lora_req.lora_name == request.model:
-                adapter_names = request.model
-                break
-        assert adapter_names is not None
-        adapter_kwargs['adapter_names'] = [adapter_names]
-    elif isinstance(model, PeftModel):
-        adapter_kwargs['adapter_names'] = ['-']
+    if _args.lora_request_list is not None:
+        if request.model != _args.model_type:
+            adapter_names = None
+            for lora_req in _args.lora_request_list:
+                if lora_req.lora_name == request.model:
+                    adapter_names = request.model
+                    break
+            assert adapter_names is not None
+            adapter_kwargs['adapter_names'] = [adapter_names]
+        elif isinstance(model, PeftModel):
+            adapter_kwargs['adapter_names'] = ['-']  # use base model

     async def _generate_full():
         generation_info = {}
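
The `_args.lora_request_list is not None` guard matters because, after this commit, the list is set to `None` for pre-quantized (GPTQ/AWQ) checkpoints (see the infer.py change below), whereas the old code assumed it was always populated. For orientation, here is a minimal client-side sketch of how `request.model` picks the adapter on the PT backend. It assumes a server started with `swift deploy` listening on localhost:8000 (the port may differ in your setup) and the `default-lora` adapter name that `handle_infer_backend` registers automatically for a LoRA checkpoint; the response shape follows the OpenAI-compatible schema the server exposes.

import requests

resp = requests.post(
    'http://localhost:8000/v1/chat/completions',
    json={
        # The base model_type selects the base weights; a name from
        # lora_request_list (e.g. 'default-lora') selects that adapter
        # via adapter_names.
        'model': 'default-lora',
        'messages': [{'role': 'user', 'content': 'Hello!'}],
        'max_tokens': 128,
    })
print(resp.json()['choices'][0]['message']['content'])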

swift/llm/export.py

+9 -1

@@ -121,7 +121,15 @@ def llm_export(args: ExportArguments) -> None:
         logger.info('Saving quantized weights...')
         model_cache_dir = model.model_dir
         save_checkpoint(
-            None, template.tokenizer, model_cache_dir, args.ckpt_dir, args.quant_output_dir, dtype=args.dtype)
+            None,
+            template.tokenizer,
+            model_cache_dir,
+            args.ckpt_dir,
+            args.quant_output_dir,
+            sft_args_kwargs={
+                'dtype': args.dtype,
+                'quant_method': args.quant_method
+            })
         logger.info(f'Successfully quantized the model and saved in {args.quant_output_dir}.')
         args.ckpt_dir = args.quant_output_dir
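
The practical effect of the new `sft_args_kwargs` argument is that a quantized export now records `quant_method` next to `dtype` in the checkpoint's `sft_args.json`, which `load_from_ckpt_dir` can later re-import at inference or deploy time. A standalone sketch of that rewrite, mirroring the `save_checkpoint` logic shown in the infer.py diff below; the paths and values are illustrative only.

import json

sft_args_kwargs = {'dtype': 'fp16', 'quant_method': 'awq'}  # hypothetical values

with open('output/checkpoint-100/sft_args.json', 'r', encoding='utf-8') as f:
    res = json.load(f)
res['sft_type'] = 'full'
for k in ['dtype', 'quant_method']:
    v = sft_args_kwargs.get(k)
    if v is not None:
        res[k] = v  # quant_method is now persisted alongside dtype
with open('quant-output/sft_args.json', 'w', encoding='utf-8') as f:
    json.dump(res, f, ensure_ascii=False, indent=2)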

swift/llm/infer.py

+18 -9

@@ -16,8 +16,8 @@
 from swift.utils import (append_to_jsonl, get_logger, get_main, get_model_info, read_multi_line, seed_everything,
                          show_layers)
 from .utils import (DeployArguments, InferArguments, Template, get_additional_saved_files, get_dataset,
-                    get_model_tokenizer, get_template, inference, inference_stream, is_adapter, sample_dataset,
-                    set_generation_config)
+                    get_model_tokenizer, get_template, inference, inference_stream, is_adapter, is_quant_model,
+                    sample_dataset, set_generation_config)

 logger = get_logger()

@@ -29,6 +29,7 @@ def save_checkpoint(model: Optional[PreTrainedModel],
                     target_dir: str,
                     *,
                     save_safetensors: bool = True,
+                    sft_args_kwargs: Dict[str, Any],
                     **kwargs) -> None:
     if model is not None:
         model.save_pretrained(target_dir, safe_serialization=save_safetensors)
@@ -75,9 +76,10 @@ def save_checkpoint(model: Optional[PreTrainedModel],
     with open(old_sft_args_path, 'r', encoding='utf-8') as f:
         res = json.load(f)
     res['sft_type'] = 'full'
-    dtype = kwargs.get('dtype')
-    if dtype is not None:
-        res['dtype'] = dtype
+    for k in ['dtype', 'quant_method']:
+        v = sft_args_kwargs.get(k)
+        if v is not None:
+            res[k] = v
     with open(new_sft_args_path, 'w', encoding='utf-8') as f:
         json.dump(res, f, ensure_ascii=False, indent=2)

@@ -89,8 +91,8 @@ def merge_lora(args: InferArguments,
     logger.info(f'replace_if_exists: {replace_if_exists}')
     assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.'
     assert args.sft_type in ('lora', 'adalora', 'longlora'), 'Only supports lora series models'
-    for s in ['int4', 'int8', 'awq']:
-        assert s not in args.model_type, f'{s} model is not supported'
+    assert not is_quant_model(
+        args.model_type), f'{args.model_type} is a quantized model and does not support merge-lora.'
     if args.quantization_bit != 0:
         logger.warning('It is not recommended to merge quantized models, '
                        'as this can result in performance degradation')
@@ -117,7 +119,7 @@ def merge_lora(args: InferArguments,
         args.ckpt_dir,
         merged_lora_path,
         save_safetensors=args.save_safetensors,
-        dtype=args.dtype)
+        sft_args_kwargs={'dtype': args.dtype})
     logger.info(f'Successfully merged LoRA and saved in {merged_lora_path}.')
     logger.info("Setting args.sft_type: 'full'")
     logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
@@ -180,6 +182,7 @@ def prepare_model_template(args: InferArguments,
         model_kwargs,
         model_id_or_path=model_id_or_path,
         revision=args.model_revision,
+        quant_method=args.quant_method,
         **kwargs)
     if verbose:
         logger.info(f'model_config: {model.config}')
@@ -207,7 +210,13 @@ def prepare_model_template(args: InferArguments,
             f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}')
     # Preparing LoRA
     if is_adapter(args.sft_type) and args.ckpt_dir is not None:
+        if is_quant_model(args.model_type, model):
+            # gptq awq does not support lora switching
+            args.lora_request_list = None
+            logger.warning('The current model does not support LoRA switching. '
+                           f'Setting args.lora_request_list: {args.lora_request_list}')
         if isinstance(args, DeployArguments) and args.lora_request_list is not None:
+            logger.info(f'args.lora_request_list: {args.lora_request_list}')
             for lora_request in args.lora_request_list:
                 model = Swift.from_pretrained(
                     model, lora_request.lora_local_path, lora_request.lora_name, inference_mode=True)
@@ -499,7 +508,7 @@ def llm_infer(args: InferArguments) -> Dict[str, List[Dict[str, Any]]]:
             kwargs['tools'] = tools
         kwargs['truncation_strategy'] = args.truncation_strategy
         if args.infer_backend == 'vllm':
-            assert args.stream is True
+            assert args.stream
            if args.verbose:
                print(f"[QUERY]{data['query']}\n[RESPONSE]", end='')
            gen = inference_stream_vllm(llm_engine, template, [kwargs], lora_request=lora_request)
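
`is_quant_model` is the helper this commit leans on throughout, but its body is not part of the diff. A plausible sketch of what it might look like, assuming it generalizes the old `'int4'/'int8'/'awq'` substring check and also inspects the `is_gptq`/`is_awq`/`is_aqlm` flags that `get_model_tokenizer` sets on pre-quantized models (the sft.py hunk relies on those flags); the real implementation lives in `swift/llm/utils/utils.py` and may differ.

from typing import Optional


def is_quant_model(model_type: Optional[str] = None, model=None) -> bool:
    # Sketch only: detect GPTQ/AWQ/AQLM checkpoints that must not be
    # re-quantized, merged, or used for LoRA switching.
    if model_type is not None:
        for k in ['int4', 'int8', 'awq', 'gptq', 'aqlm']:
            if k in model_type:
                return True
    if model is not None:
        for k in ['gptq', 'awq', 'aqlm']:
            if getattr(model, f'is_{k}', None):
                return True
    return False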

swift/llm/rlhf.py

+2 -6

@@ -94,12 +94,6 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     kwargs['use_flash_attn'] = args.use_flash_attn
     if args.local_repo_path:
         kwargs['local_repo_path'] = args.local_repo_path
-    if args.quant_method == 'awq':
-        kwargs['is_awq'] = True
-    elif args.quant_method == 'aqlm':
-        kwargs['is_aqlm'] = True
-    elif args.quant_method == 'gptq':
-        kwargs['is_gptq'] = True

     if args.rope_scaling:
         kwargs['rope_scaling'] = args.rope_scaling
@@ -111,6 +105,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
         model_kwargs,
         model_id_or_path=args.model_id_or_path,
         revision=args.model_revision,
+        quant_method=args.quant_method,
         is_training=True,
         **kwargs)
     logger.info(f'model_config: {model.config}')
@@ -155,6 +150,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
             model_kwargs,
             model_id_or_path=args.ref_model_id_or_path,
             revision=args.model_revision,
+            quant_method=args.quant_method,
             **kwargs)
     else:
         ref_model = None

swift/llm/sft.py

+6 -7

@@ -100,25 +100,24 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     kwargs['use_flash_attn'] = args.use_flash_attn
     if args.local_repo_path:
         kwargs['local_repo_path'] = args.local_repo_path
-    if args.quant_method == 'awq':
-        kwargs['is_awq'] = True
-    elif args.quant_method == 'aqlm':
-        kwargs['is_aqlm'] = True
-    elif args.quant_method == 'gptq':
-        kwargs['is_gptq'] = True

     if args.rope_scaling:
         kwargs['rope_scaling'] = args.rope_scaling
-        kwargs['max_length'] = args.max_length

     model, tokenizer = get_model_tokenizer(
         args.model_type,
         args.torch_dtype,
         model_kwargs,
         model_id_or_path=args.model_id_or_path,
         revision=args.model_revision,
+        quant_method=args.quant_method,
         is_training=True,
         **kwargs)
+    for k in ['gptq', 'awq', 'aqlm']:
+        if getattr(model, f'is_{k}', None):
+            args.quant_method = k
+            logger.info(f'Setting args.quant_method: {args.quant_method}')
+            break
     logger.info(f'model_config: {model.config}')
     generation_config = GenerationConfig(
         max_new_tokens=args.max_new_tokens,
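
Passing `quant_method` straight into `get_model_tokenizer` replaces the old `is_awq`/`is_aqlm`/`is_gptq` kwargs, and the new loop reads the detection back from the loaded model, so `args.quant_method` is filled in even when the user did not pass `--quant_method` and ends up in the saved `sft_args.json`. A tiny self-contained illustration of that back-fill, using a hypothetical stand-in object rather than a real model:

# Stand-in only: get_model_tokenizer marks pre-quantized checkpoints
# with attributes such as is_gptq / is_awq / is_aqlm.
class _FakeGPTQModel:
    is_gptq = True


model = _FakeGPTQModel()
quant_method = None
for k in ['gptq', 'awq', 'aqlm']:
    if getattr(model, f'is_{k}', None):
        quant_method = k
        break
print(quant_method)  # -> 'gptq'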

swift/llm/utils/__init__.py

+3 -3

@@ -21,9 +21,9 @@
                        get_template, register_template)
 from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, find_all_linears, find_embedding,
                     find_ln, get_max_model_len, get_time_info, history_to_messages, inference, inference_stream,
-                    is_vllm_available, limit_history_length, messages_join_observation, messages_to_history,
-                    print_example, safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset,
-                    to_device)
+                    is_quant_model, is_vllm_available, limit_history_length, messages_join_observation,
+                    messages_to_history, print_example, safe_tokenizer_decode, set_generation_config,
+                    sort_by_max_length, stat_dataset, to_device)

 try:
     if is_vllm_available():

swift/llm/utils/argument.py

+14 -13

@@ -30,7 +30,7 @@
 from .model import (MODEL_MAPPING, dtype_mapping, get_additional_saved_files, get_default_lora_target_modules,
                     get_default_template_type)
 from .template import TEMPLATE_MAPPING
-from .utils import is_vllm_available
+from .utils import is_quant_model, is_vllm_available

 logger = get_logger()

@@ -675,15 +675,15 @@ def load_from_checkpoint(self) -> None:
         with open(sft_args_path, 'r', encoding='utf-8') as f:
             sft_args = json.load(f)
         imported_keys = [
-            'model_type', 'model_revision', 'quantization_bit', 'dtype', 'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
-            'bnb_4bit_use_double_quant', 'model_id_or_path'
+            'model_type', 'model_revision', 'quant_method', 'quantization_bit', 'dtype', 'bnb_4bit_comp_dtype',
+            'bnb_4bit_quant_type', 'bnb_4bit_use_double_quant', 'model_id_or_path'
         ]

         for key in imported_keys:
             value = getattr(self, key)
             if key in {'dtype', 'bnb_4bit_comp_dtype'} and value != 'AUTO':
                 continue
-            if key in {'model_type', 'model_revision', 'model_id_or_path'} and value is not None:
+            if key in {'model_type', 'model_revision', 'model_id_or_path', 'quant_method'} and value is not None:
                 continue
             setattr(self, key, sft_args.get(key))

@@ -820,8 +820,9 @@ def __post_init__(self) -> None:
                     'lora does not support `freeze_parameters`, please set `--sft_type full`')
                 assert len(self.additional_trainable_parameters) == 0, (
                     'lora does not support `additional_trainable_parameters`, please set `--sft_type full`')
-                if 'int4' in self.model_type or 'int8' in self.model_type or 'awq' in self.model_type:
-                    assert self.quantization_bit == 0, 'int4, int8 or awq models do not need to be quantized again.'
+                if is_quant_model(self.model_type):
+                    assert self.quantization_bit == 0, (
+                        f'{self.model_type} is already a quantized model and does not need to be quantized again.')
                 if self.learning_rate is None:
                     self.learning_rate = 1e-4
                 if self.save_only_model is None:
@@ -1026,7 +1027,7 @@ def _init_training_args(self) -> None:
         self.training_args = training_args

     def _handle_pai_compat(self) -> None:
-        assert is_pai_training_job() is True
+        assert is_pai_training_job()
         logger.info('Handle pai compat...')
         pai_tensorboard_dir = get_pai_tensorboard_dir()
         if self.logging_dir is None and pai_tensorboard_dir is not None:
@@ -1075,7 +1076,8 @@ class InferArguments(ArgumentsBase):
     model_name: List[str] = field(default_factory=lambda: [None, None], metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
     model_author: List[str] = field(
         default_factory=lambda: [None, None], metadata={'help': "e.g. ['魔搭', 'ModelScope']"})
-    quant_method: Literal['bnb', 'hqq', 'eetq'] = None
+    # 'awq', 'gptq', 'aqlm' are used for inference on pre-quantized models.
+    quant_method: Literal['bnb', 'hqq', 'eetq', 'awq', 'gptq', 'aqlm'] = None
     quantization_bit: Literal[0, 1, 2, 3, 4, 8] = 0  # hqq: 1,2,3,4,8. bnb: 4,8
     hqq_axis: Literal[0, 1] = 0
     hqq_dynamic_config_path: Optional[str] = None
@@ -1211,14 +1213,13 @@ def handle_infer_backend(self):
             if not support_vllm:
                 logger.warning(f'vllm not support `{self.model_type}`')
             if self.sft_type == 'lora' and not self.vllm_enable_lora:
-                assert self.merge_lora is True, ('To use VLLM, you need to provide the complete weight parameters. '
-                                                 'Please set `--merge_lora true`.')
+                assert self.merge_lora, ('To use VLLM, you need to provide the complete weight parameters. '
+                                         'Please set `--merge_lora true`.')
         if (self.infer_backend == 'vllm' and self.vllm_enable_lora
                 or self.infer_backend == 'pt' and isinstance(self, DeployArguments) and self.sft_type == 'lora'):
             assert self.ckpt_dir is not None
             self.lora_modules.append(f'default-lora={self.ckpt_dir}')
         self.lora_request_list = _parse_lora_modules(self.lora_modules, self.infer_backend == 'vllm')
-        logger.info(f'args.lora_request_list: {self.lora_request_list}')

         template_info = TEMPLATE_MAPPING[self.template_type]
         if self.num_beams != 1:
@@ -1236,7 +1237,7 @@ def load_from_ckpt_dir(self) -> None:
         with open(sft_args_path, 'r', encoding='utf-8') as f:
             sft_args = json.load(f)
         imported_keys = [
-            'model_type', 'model_revision', 'sft_type', 'template_type', 'system', 'quantization_bit',
+            'model_type', 'model_revision', 'sft_type', 'template_type', 'system', 'quant_method', 'quantization_bit',
             'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type', 'bnb_4bit_use_double_quant', 'rope_scaling'
         ]
         if self.load_dataset_config:
@@ -1248,7 +1249,7 @@ def load_from_ckpt_dir(self) -> None:
             value = getattr(self, key)
             if key in {'dataset', 'val_dataset'} and len(value) > 0:
                 continue
-            if key in {'dataset_test_ratio', 'system'} and value is not None:
+            if key in {'dataset_test_ratio', 'system', 'quant_method'} and value is not None:
                 continue
             setattr(self, key, sft_args.get(key))
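
With `'quant_method'` added to `imported_keys` and to the skip rule, the value stored in a checkpoint's `sft_args.json` is reused at inference and deploy time unless the user overrides it on the command line. A minimal sketch of that precedence, with illustrative values only:

# What load_from_ckpt_dir effectively does for quant_method.
sft_args = {'quant_method': 'gptq', 'quantization_bit': 0}  # read from sft_args.json
cli_value = None  # nothing passed via --quant_method

# A non-None CLI value wins; otherwise the checkpoint value is imported.
quant_method = cli_value if cli_value is not None else sft_args.get('quant_method')
print(quant_method)  # -> 'gptq'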
