
Commit 851405c

Authored by yiliu30, Yi4Liu, KepingYan, kepingyan, and mfylcek
Re-quantize the Official FP8 Model Using INC (vllm-project#959)

Docs: https://github.com/yiliu30/vllm-fork/blob/r1-woq-requant/scripts/DEEPSEEK_R1_ON_GAUDI.md#requantize-the-official-fp8-model-using-inc

gsm8k result:

```bash
Running generate_until requests: 100%|█████████████| 64/64 [06:10<00:00, 5.78s/it]
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9688|±  |0.0219|
|     |       |strict-match    |     5|exact_match|↑  |0.9531|±  |0.0266|
```

cc @thuang6 @yangulei

---------

Signed-off-by: Yi Liu <[email protected]>
Signed-off-by: Yi <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: Yi Liu <[email protected]>
Co-authored-by: KepingYan <[email protected]>
Co-authored-by: kepingyan <[email protected]>
Co-authored-by: Marceli Fylcek <[email protected]>
Co-authored-by: Youlei Yang <[email protected]>
Co-authored-by: Wei Lin <[email protected]>
Co-authored-by: Tony Lin <[email protected]>
Co-authored-by: Bob Zhu <[email protected]>
Co-authored-by: Bob Zhu <[email protected]>
1 parent e5da58c commit 851405c

14 files changed: +544 −22 lines

requirements-hpu.txt

Lines changed: 2 additions & 0 deletions
@@ -9,3 +9,5 @@ tabulate
 setuptools>=61
 setuptools-scm>=8
 vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@ecb60e4
+# FIXME: (Yi) Replace it with the INC 3.4
+git+https://github.com/intel/neural-compressor.git@r1-woq
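The new requirement pins INC to the `r1-woq` branch rather than a released version (hence the FIXME). For a manually managed environment, an equivalent install would look like the sketch below; this is only illustrative and assumes the same pinned ref as the requirements file:

```bash
# Illustrative only: install the pinned INC branch directly instead of via requirements-hpu.txt.
pip install "git+https://github.com/intel/neural-compressor.git@r1-woq"
```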

scripts/DEEPSEEK_R1_ON_GAUDI.md

Lines changed: 26 additions & 0 deletions
@@ -93,3 +93,29 @@ ray start --address='${head_ip}:6379' --resources='{"HPU": 8, "TPU": 0}'
 python scripts/run_example_tp_2nodes.py --model ${YOUR_PATH}/DeepSeek-R1-static
 ```

+# Requantize the Official FP8 Model Using INC
+- INC: https://github.com/intel/neural-compressor/tree/r1-woq
+
+- Calibration
+> [!Note]
+> This step will take a while. You can skip it by downloading the pre-calibration result.
+>
+> `huggingface-cli download Yi30/inc-woq-full-pile-512-1024-331 --local-dir ./scripts/nc_workspace_measure_kvache`
+
+```bash
+export OFFICIAL_FP8_MODEL=deepseek-ai/DeepSeek-R1
+cd ./scripts
+VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_measure_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_example_tp.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --osl 32 --max_num_seqs 1 --nprompts 512 --dataset pile
+```
+
+- Quantization
+```bash
+cd ./scripts
+VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_quant_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_example_tp.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --osl 32 --max_num_seqs 1 --fp8_kv_cache
+```
+
+- Evaluation
+```bash
+cd ./scripts
+VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_quant_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_lm_eval.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --fp8_kv_cache -l 64 --batch_size 1
+```

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# FIXME: (Yi) remove it before merge
#!/bin/bash
tp_parrallel=8
in_len=1024
out_len=1024
multi_step=1
total_len=$((in_len + out_len))
# if total_len is not a multiple of 128, round up to the next multiple of 128
if [ $((total_len % 128)) -ne 0 ]; then
    echo 'round up for 128'
    total_len=$(((total_len / 128 + 1) * 128 ))
fi
ep_size=8
moe_n_slice=1
gpu_utils=0.92
bs=448
num_prompts=448
request_rate=inf
log_name="[inc-331-moe-op-maxabs_hw-scalar-online-gaudi3-${gpu_utils}util-TPparallel${tp_parrallel}-EP${ep_size}-loop${moe_n_slice}moegroups-multistep${multi_step}_nprompt${num_prompts}_rrate${request_rate}_bs${bs}_i${in_len}_o${out_len}_mdllen${total_len}"

VLLM_DECODE_BLOCK_BUCKET_MIN=$((in_len * bs / 128))
VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))
# model="/data/models/DeepSeek-R1-static/"
# tokenizer="/data/models/DeepSeek-R1-static/"
model="/data/models/DeepSeek-R1/"
tokenizer="/data/models/DeepSeek-R1/"
model_name="DeepSeek-R1"


QUANT_CONFIG="inc_quant_with_fp8kv_config.json" \
VLLM_REQUANT_FP8_INC=1 \
VLLM_ENABLE_RUNTIME_DEQUANT=1 \
VLLM_DELAYED_SAMPLING=true \
HABANA_VISIBLE_DEVICES="ALL" \
VLLM_MOE_N_SLICE=${moe_n_slice} \
VLLM_EP_SIZE=${ep_size} \
VLLM_MLA_DISABLE_REQUANTIZATION=1 \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
VLLM_PROMPT_BS_BUCKET_MIN=1 \
VLLM_PROMPT_BS_BUCKET_MAX=16 \
VLLM_PROMPT_SEQ_BUCKET_MIN=${in_len} \
VLLM_PROMPT_SEQ_BUCKET_MAX=${in_len} \
VLLM_DECODE_BS_BUCKET_MIN=${bs} \
VLLM_DECODE_BS_BUCKET_MAX=${bs} \
VLLM_DECODE_BLOCK_BUCKET_MIN=${VLLM_DECODE_BLOCK_BUCKET_MIN} \
VLLM_DECODE_BLOCK_BUCKET_MAX=${VLLM_DECODE_BLOCK_BUCKET_MAX} \
python -m vllm.entrypoints.openai.api_server \
    --port 8080 \
    --model ${model} \
    --tensor-parallel-size ${tp_parrallel} \
    --max-num-seqs ${bs} \
    --disable-log-requests \
    --dtype bfloat16 \
    --use-v2-block-manager \
    --num_scheduler_steps ${multi_step} \
    --max-model-len 4096 \
    --distributed_executor_backend mp \
    --gpu_memory_utilization ${gpu_utils} \
    --kv_cache_dtype "fp8_inc" \
    --trust_remote_code 2>&1 | tee benchmark_logs/${log_name}_serving.log &
pid=$(($!-1))  # assume the python server process sits just before tee ($!)

until [[ "$n" -ge 1000 ]] || [[ $ready == true ]]; do  # wait for the server to come up (up to ~5000s)
    n=$((n+1))
    if grep -q "Started server process" benchmark_logs/${log_name}_serving.log; then
        break
    fi
    sleep 5s
done
sleep 10s
echo ${pid}

hl-smi -l | tee benchmark_logs/${log_name}_smi.log &
hl_pid=$(($!-1))


start_time=$(date +%s)
echo "Start to benchmark"
python ../benchmarks/benchmark_serving.py --backend vllm --model ${model} --tokenizer ${tokenizer} --dataset-name sonnet --dataset-path ../benchmarks/sonnet.txt --request-rate ${request_rate} --num-prompts ${num_prompts} --port 8080 --sonnet-input-len ${in_len} --sonnet-output-len ${out_len} --sonnet-prefix-len 100 2>&1 | tee benchmark_logs/${log_name}_run1.log
end_time=$(date +%s)
echo "Time elapsed: $((end_time - start_time))s"

sleep 10

# start_time=$(date +%s)
# echo "Start to benchmark"
# python benchmarks/benchmark_serving.py --backend vllm --model ${model} --tokenizer ${tokenizer} --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --request-rate ${request_rate} --num-prompts ${num_prompts} --port 8080 --sonnet-input-len ${in_len} --sonnet-output-len ${out_len} --sonnet-prefix-len 100 2>&1 | tee benchmark_logs/${log_name}_run2.log
# end_time=$(date +%s)
# echo "Time elapsed: $((end_time - start_time))s"

# sleep 10

kill ${pid}
kill ${hl_pid}
#--backend openai-chat --endpoint "v1/chat/completions"

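Once the script above logs "Started server process", the OpenAI-compatible endpoint on port 8080 can be smoke-tested before the benchmark kicks in. The sketch below is not part of the commit; it assumes the standard vLLM completions route and reuses the same model path the script passes via `--model`:

```bash
# Hypothetical sanity check against the server launched above (port 8080).
# The "model" field must match the --model value, here /data/models/DeepSeek-R1/.
curl -s http://localhost:8080/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "/data/models/DeepSeek-R1/", "prompt": "Hello, my name is", "max_tokens": 16}'
```
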
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
{
    "method": "HOOKS",
    "mode": "MEASURE",
    "observer": "maxabs",
    "whitelist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": ["lm_head", "mlp\\.gate\\b"]
    },
    "quantize_weight": false,
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}
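The escaped `"mlp\\.gate\\b"` entry suggests the blocklist names are treated as regular expressions against module names, presumably to exclude the MoE router gate while leaving projections such as `gate_proj` eligible for measurement. A quick illustrative check of what the decoded pattern matches (the layer names below are made up, not taken from the commit):

```bash
# Illustrative only: inspect the decoded blocklist pattern mlp\.gate\b with Python's re module.
python3 - <<'PY'
import re
pattern = re.compile(r"mlp\.gate\b")  # as the JSON string "mlp\\.gate\\b" decodes
for name in ["model.layers.3.mlp.gate", "model.layers.0.mlp.gate_proj"]:
    print(name, "-> blocked" if pattern.search(name) else "-> eligible")
PY
```
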
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "matmul_qk",
            "matmul_av",
            "batch2block_matmul",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "matmul_qk",
            "matmul_av",
            "batch2block_matmul",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}

scripts/run_example_tp.py

Lines changed: 75 additions & 1 deletion
@@ -25,6 +25,7 @@
 parser.add_argument("--isl", type=int, default=1024, help="input sequence length.")
 parser.add_argument("--osl", type=int, default=1024, help="output sequence length.")
 parser.add_argument("--nprompts", type=int, default=4, help="The number of prompts.")
+parser.add_argument("--max_num_seqs", type=int, default=None, help="The max number of sequences.")
 parser.add_argument("--random", action="store_true", help="Randomly sample prompts.")
 parser.add_argument("--fp8_kv_cache", action="store_true", help="Use fp8 for kv cache.")
 args = parser.parse_args()
@@ -160,6 +161,70 @@ def sample_gsm8k_requests(
         tokenizer=tokenizer,
         do_random=args.random,
     )
+elif args.dataset == "pile":
+
+    def reset_seed(seed=42):
+        import torch
+        import random
+        import numpy as np
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+    def get_prompt_token_ids(model_path, prompts, max_length=1024):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        prompt_token_ids = []
+        for prompt in prompts:
+            tokens = tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=max_length,
+            )
+            if len(tokens.input_ids[0]) < max_length:
+                continue
+            prompt_token_ids.append([x.item() for x in tokens.input_ids[0]])
+        return prompt_token_ids
+
+    def get_pile_prompts(model_name, num_samples=512):
+        from datasets import load_dataset
+        from tqdm import tqdm
+        import transformers
+
+        least_tokens = 1024
+        seed = 42
+
+        reset_seed(seed)
+
+        dataset = load_dataset("NeelNanda/pile-10k", split="train")
+        dataset = dataset.shuffle(seed=seed)
+
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_name, trust_remote_code=True
+        )
+        num_sample = 0
+        samples_lst = []
+        for data in tqdm(dataset):
+            prompt = data["text"]
+            tokens = tokenizer(prompt, return_tensors="pt")
+            if len(tokens.input_ids[0]) < least_tokens:
+                continue
+            num_sample += 1
+            samples_lst.append(prompt)
+            if num_sample >= num_samples:
+                break
+        return samples_lst
+
+    least_tokens = args.isl
+    num_samples = args.nprompts
+    prompts = get_pile_prompts(args.model, num_samples)
+    prompt_token_ids = get_prompt_token_ids(
+        args.model, prompts, least_tokens
+    )
+    print(f"Got {len(prompts)} prompts, length of first prompt: {len(prompt_token_ids[0])}.")
+    gt = None
 else:
     prompts = [
         "Hello, my name is",
@@ -178,6 +243,8 @@ def sample_gsm8k_requests(
 param = {}
 if args.fp8_kv_cache:
     param["kv_cache_dtype"] = "fp8_inc"
+if args.max_num_seqs is not None:
+    param["max_num_seqs"] = args.max_num_seqs
 if args.tp_size == 1:
     llm = LLM(
         model=model,
@@ -204,7 +271,12 @@ def sample_gsm8k_requests(
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 start = time.perf_counter()
-outputs = llm.generate(prompts, sampling_params)
+if args.dataset == "pile":
+    outputs = llm.generate(
+        prompts=None, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids
+    )
+else:
+    outputs = llm.generate(prompts, sampling_params)
 end = time.perf_counter()
 # Print the outputs.
 print(f"e2e took {end - start} seconds")
@@ -218,4 +290,6 @@ def sample_gsm8k_requests(
     print(f"Generated text: {generated_text!r}")
     print(f"Ground truth: {gt_i!r}")
     print("====================================")
+if os.getenv("VLLM_REQUANT_FP8_INC", None) is not None:
+    llm.llm_engine.model_executor.shutdown()
 del llm

scripts/run_lm_eval.py

Lines changed: 16 additions & 4 deletions
@@ -16,8 +16,9 @@
 parser.add_argument("--tokenizer", type=str, default=None, help="The model path.")
 parser.add_argument("--tp_size", type=int, default=8, help="Tensor Parallelism size.")
 parser.add_argument("--ep_size", type=int, default=8, help="Expert Parallelism size.")
-parser.add_argument("-l", "--limit", type=int, default=64, help="test request counts.")
+parser.add_argument("-l", "--limit", type=int, default=None, help="test request counts.")
 parser.add_argument("--batch_size", type=int, default=1, help="The batch size.")
+parser.add_argument("--fp8_kv_cache", action="store_true", help="Use fp8 for kv cache.")
 args = parser.parse_args()

 os.environ["VLLM_SKIP_WARMUP"] = "true"
@@ -44,6 +45,9 @@
 model = args.model
 if args.tokenizer is None:
     args.tokenizer = model
+param = {}
+if args.fp8_kv_cache:
+    param["kv_cache_dtype"] = "fp8_inc"
 if args.tp_size == 1:
     llm = VLLM(
         pretrained=model,
@@ -65,17 +69,25 @@
         dtype="bfloat16",
         gpu_memory_utilization=0.8,
         batch_size=args.batch_size,
+        **param,
     )


 # Run the evaluation; you can adjust num_fewshot and batch_size as needed.
 start = time.perf_counter()
 if args.task == "gsm8k":
-    results = simple_evaluate(model=llm, tasks=["gsm8k"], num_fewshot=5, batch_size=8, limit=args.limit)
+    from lm_eval.utils import make_table
+
+    results = simple_evaluate(
+        model=llm,
+        tasks=["gsm8k"],
+        limit=args.limit,
+    )
     end = time.perf_counter()
     e2e = end - start
+    print(make_table(results))
     # save as json
-    with open(f"gsm8k_ep{args.ep_size}_result_samples_limit{args.limit}.jsonl", "w") as f:
+    with open(f"gsm8k_ep{args.ep_size}_result_samples_limit{str(args.limit)}.jsonl", "w") as f:
         json.dump(results['results'], f)
         json.dump({"e2e time(secs)": e2e}, f)
         f.write("\n")
@@ -86,7 +98,7 @@
     results = simple_evaluate(model=llm, tasks=["hellaswag"], num_fewshot=0, batch_size=8, limit=args.limit)
     end = time.perf_counter()
     e2e = end - start
-    with open(f"hallaswag_ep{args.ep_size}_result_samples_limit{args.limit}.jsonl", "w") as f:
+    with open(f"hallaswag_ep{args.ep_size}_result_samples_limit{str(args.limit)}.jsonl", "w") as f:
         json.dump(results['results'], f)
         json.dump({"e2e time(secs)": e2e}, f)
         f.write("\n")
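Note that this diff changes the default of `-l/--limit` from 64 to None, so omitting the flag now evaluates the full gsm8k test set, while the accuracy table in the commit message corresponds to a 64-request run (`-l 64`, as in the Evaluation step of the docs). A full-set run would look like the sketch below (illustrative, using the same environment variables as the docs):

```bash
# Illustrative full-set gsm8k run: same setup as the docs' Evaluation step, without -l.
export OFFICIAL_FP8_MODEL=deepseek-ai/DeepSeek-R1
cd ./scripts
VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_quant_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 \
python run_lm_eval.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --fp8_kv_cache --batch_size 1
```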
