Skip to content

Commit 2b8ad87

Browse files
authored
Support pipeline parallel for glm-4v (#11545)
1 parent 7f5111a commit 2b8ad87

File tree

5 files changed

+179
-18
lines changed

5 files changed

+179
-18
lines changed

python/llm/example/GPU/Pipeline-Parallel-Inference/README.md

+15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
1717
- [Qwen/Qwen-VL-Chat](./run_qwen_vl_arc_2_card.sh)
1818
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
1919
- [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh)
20+
- [THUDM/glm-4v-9b](./run_glm_4v_arc_2_card.sh)
2021
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
2122
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
2223
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
@@ -145,6 +146,20 @@ bash run_chatglm_arc_2_card.sh
145146

146147
</details>
147148

149+
<details>
150+
<summary> Show glm-4v example </summary>
151+
152+
#### Run glm-4v-9b on two Intel Arc A770
153+
154+
You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for glm-4v-9b to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
155+
156+
```bash
157+
pip install transformers==4.37.0 tiktoken
158+
bash run_glm_4v_arc_2_card.sh
159+
```
160+
161+
</details>
162+
148163
</details>
149164

150165
<details>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#
2+
# Copyright 2016 The BigDL Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import os
18+
import time
19+
import torch
20+
import argparse
21+
import requests
22+
23+
from PIL import Image
24+
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
25+
from transformers import AutoTokenizer
26+
27+
init_pipeline_parallel()
28+
29+
if __name__ == '__main__':
30+
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for THUDM/glm-4v-9b model')
31+
parser.add_argument('--repo-id-or-model-path', type=str, default="THUDM/glm-4v-9b",
32+
help='The huggingface repo id for the THUDM/glm-4v-9b model to be downloaded'
33+
', or the path to the huggingface checkpoint folder')
34+
parser.add_argument('--image-url-or-path', type=str,
35+
default='http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg',
36+
help='The URL or path to the image to infer')
37+
parser.add_argument('--prompt', type=str, default="这是什么?",
38+
help='Prompt to infer')
39+
parser.add_argument('--n-predict', type=int, default=32,
40+
help='Max tokens to predict')
41+
parser.add_argument('--low-bit', type=str, default='sym_int4', help='The quantization type the model will convert to.')
42+
parser.add_argument('--gpu-num', type=int, default=2, help='GPU number to use')
43+
44+
args = parser.parse_args()
45+
model_path = args.repo_id_or_model_path
46+
image_path = args.image_url_or_path
47+
48+
model = AutoModelForCausalLM.from_pretrained(model_path,
49+
load_in_low_bit=args.low_bit,
50+
optimize_model=True,
51+
trust_remote_code=True,
52+
use_cache=True,
53+
pipeline_parallel_stages=args.gpu_num)
54+
model = model.half()
55+
56+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
57+
local_rank = torch.distributed.get_rank()
58+
59+
query = args.prompt
60+
if os.path.exists(image_path):
61+
image = Image.open(image_path)
62+
else:
63+
image = Image.open(requests.get(image_path, stream=True).raw)
64+
65+
# here the prompt tuning refers to https://huggingface.co/THUDM/glm-4v-9b/blob/main/README.md
66+
inputs = tokenizer.apply_chat_template([{"role": "user", "image": image, "content": query}],
67+
add_generation_prompt=True,
68+
tokenize=True,
69+
return_tensors="pt",
70+
return_dict=True) # chat mode
71+
inputs = inputs.to(f'xpu:{local_rank}')
72+
all_input = [{'image': image_path}, {'text': query}]
73+
74+
# Generate predicted tokens
75+
with torch.inference_mode():
76+
gen_kwargs = {"max_new_tokens": args.n_predict, "do_sample": False,}
77+
st = time.time()
78+
outputs = model.generate(**inputs, **gen_kwargs)
79+
outputs = outputs[:, inputs['input_ids'].shape[1]:]
80+
end = time.time()
81+
if local_rank == args.gpu_num - 1:
82+
print(f'Inference time: {end-st} s')
83+
output_str = tokenizer.decode(outputs[0])
84+
print('-'*20, 'Input', '-'*20)
85+
print(f'Message: {all_input}')
86+
print('-'*20, 'Output', '-'*20)
87+
print(output_str)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#
2+
# Copyright 2016 The BigDL Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
source /opt/intel/oneapi/setvars.sh
18+
export MASTER_ADDR=127.0.0.1
19+
export MASTER_PORT=9090
20+
export FI_PROVIDER=tcp
21+
export USE_XETLA=OFF
22+
export OMP_NUM_THREADS=6
23+
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
24+
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
25+
fi
26+
export TORCH_LLM_ALLREDUCE=0
27+
28+
NUM_GPUS=2 # number of used GPU
29+
# To run glm-4v-9b
30+
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
31+
glm_4v_generate.py --repo-id-or-model-path 'THUDM/glm-4v-9b' --gpu-num $NUM_GPUS --low-bit 'sym_int4'

python/llm/src/ipex_llm/transformers/models/chatglm4v.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ def chatglm4v_model_forward(
5555
# generate mode with past_key_values. the image features are already mapped
5656
if past_key_values is None:
5757
# not allow for inputs_embeds, because we want to process image feature
58-
invalidInputError(input_ids is not None and inputs_embeds is None,
59-
f"{input_ids} should not be None, {inputs_embeds} should be None.")
60-
if not is_empty(images): # multi-modality
58+
if not is_empty(images) and input_ids is not None: # multi-modality
6159
image_size: int = self.config.vision_config['image_size']
6260
patch_size: int = self.config.vision_config['patch_size']
6361
num_patches = (image_size // patch_size // 2) ** 2
@@ -99,10 +97,13 @@ def chatglm4v_model_forward(
9997
use_cache = use_cache if use_cache is not None else self.config.use_cache
10098
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
10199

102-
batch_size, seq_length = input_ids.shape
103-
104100
if inputs_embeds is None:
101+
batch_size, seq_length = input_ids.shape
105102
inputs_embeds = self.embedding(input_ids)
103+
else:
104+
batch_size, seq_length, _ = inputs_embeds.shape
105+
input_ids = torch.empty((batch_size, seq_length),
106+
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
106107

107108
if full_attention_mask is None:
108109
if (attention_mask is not None and not attention_mask.all()) or\

python/llm/src/ipex_llm/transformers/pipeline_parallel.py

+40-13
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,14 @@ def generate(
229229
generation_config.pad_token_id = eos_token_id
230230

231231
if generation_config is not None and generation_config.max_new_tokens is not None:
232-
max_new_tokens = generation_config.max_new_tokens
232+
max_new_tokens = generation_config.pop("max_new_tokens")
233233
else:
234-
max_new_tokens = kwargs.get("max_new_tokens", None)
234+
max_new_tokens = kwargs.pop("max_new_tokens", None)
235235

236236
return self.pipeline_parallel_generate(inputs=inputs,
237237
max_new_tokens=max_new_tokens,
238-
generation_config=generation_config,)
238+
generation_config=generation_config,
239+
**kwargs)
239240

240241
return original_generate(self,
241242
inputs=inputs,
@@ -257,6 +258,23 @@ def pipeline_parallel_generate(self,
257258
max_new_tokens: int = 32,
258259
generation_config: Optional[GenerationConfig] = None,
259260
**kwargs):
261+
model_kwargs = generation_config.update(**kwargs)
262+
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
263+
inputs, generation_config.bos_token_id, model_kwargs
264+
)
265+
bs = inputs_tensor.shape[0]
266+
if self.config.is_encoder_decoder:
267+
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
268+
batch_size=bs,
269+
model_input_name=model_input_name,
270+
model_kwargs=model_kwargs,
271+
decoder_start_token_id=generation_config.decoder_start_token_id,
272+
bos_token_id=generation_config.bos_token_id,
273+
device=inputs_tensor.device,
274+
)
275+
else:
276+
input_ids = inputs_tensor if model_input_name == "input_ids" \
277+
else model_kwargs.pop("input_ids")
260278
local_rank = dist.get_rank()
261279
pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
262280
next_rank = (local_rank + 1) % self.pipeline_parallel_stages
@@ -272,36 +290,44 @@ def pipeline_parallel_generate(self,
272290
eos_token_id = generation_config.eos_token_id
273291
if isinstance(eos_token_id, int):
274292
eos_token_id = [eos_token_id]
275-
eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \
293+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \
276294
if eos_token_id is not None else None
277295

278296
_input_ids = None
279297
_past_key_values = None
280-
bs = inputs.shape[0]
281-
output_ids = inputs.clone()
298+
299+
bs = input_ids.shape[0]
300+
output_ids = input_ids.clone()
282301
_check_quantize_kv_cache(self, layer_start, bs)
283302

284303
step = 0
285304
# keep track of which sequences are already finished
286-
unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
305+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
287306
this_peer_finished = False
288307
while True:
289308
if step >= max_new_tokens:
290309
break
291310

292311
if _input_ids is None:
293-
_input_ids = inputs
312+
_input_ids = input_ids
294313

295314
tic = time.time()
296315
if local_rank == 0:
297316
outputs = self(input_ids=_input_ids, inputs_embeds=None,
298-
past_key_values=_past_key_values, use_cache=True)
317+
past_key_values=_past_key_values, use_cache=True, **model_kwargs)
299318
else:
300-
inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
319+
_inputs_shape = _input_ids.shape + (self.config.hidden_size,)
320+
if step == 0 and self.config.model_type == "chatglm" \
321+
and hasattr(self.config, "vision_config"):
322+
# for glm-4v, image features are mapped during 1st token
323+
# 1597 are computed according to computation process of conv
324+
_images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1]
325+
_inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,)
326+
inputs_embeds = torch.empty(_inputs_shape,
301327
device=f'xpu:{local_rank}', dtype=self.dtype)
302328
dist.recv(inputs_embeds, src=pre_rank)
303329
outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
304-
past_key_values=_past_key_values, use_cache=True)
330+
past_key_values=_past_key_values, use_cache=True, **model_kwargs)
305331

306332
if local_rank == self.pipeline_parallel_stages - 1:
307333
logits = outputs.logits
@@ -323,7 +349,8 @@ def pipeline_parallel_generate(self,
323349
"make sure that `pad_token_id` is defined.")
324350
next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
325351

326-
if self.config.model_type == "chatglm" and self.config.num_layers == 40:
352+
if self.config.model_type == "chatglm" and self.config.num_layers == 40 \
353+
and not hasattr(self.config, "vision_config"):
327354
# for glm-4-9b-chat
328355
if step == 0:
329356
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
@@ -337,7 +364,7 @@ def pipeline_parallel_generate(self,
337364
_past_key_values = outputs.past_key_values
338365
elif self.config.model_type in ["baichuan", "chatglm"] or \
339366
(self.config.model_type == "qwen" and hasattr(self.config, "visual")):
340-
# for baichuan2, chatglm3, Qwen-VL-Chat
367+
# for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b
341368
if local_rank != 0:
342369
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
343370
past_key_values_placeholder = tuple(

0 commit comments

Comments
 (0)