
Commit a3ecb63

Update LLaMA attention fusions (#19200)
### Description
This PR updates the LLaMA-2 attention fusions by adding the following:
- Loading the PyTorch model from Hugging Face with the `LlamaAttention` class before exporting
- Updating the attention mask pattern matching to support another case

This PR also fixes [this issue](#19040).

### Motivation and Context
Recent changes to Hugging Face's `transformers` library break the existing pattern matching. The attention fusions aim to change the graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to `LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, so it ultimately does not matter which nodes comprise the `Set of Attention Nodes`: they are all removed and replaced by the `Attention Op` in the end. It therefore also does not matter whether the `LlamaAttention` class or a different attention class is used to load the PyTorch model before exporting, because the expected graphs after the attention fusions look identical regardless of which attention class is chosen. By loading the PyTorch model with the `LlamaAttention` class instead of another attention class (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then exporting it to ONNX, the existing pattern matching continues to work.
1 parent eaf047c commit a3ecb63

5 files changed (+46 −25 lines)


onnxruntime/python/tools/transformers/fusion_rotary_attention.py

+10
@@ -539,6 +539,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
 
         # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
         # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
+        # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
+        # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
         attn_mask, add_qk_str = "", ""
         attn_mask_nodes_1 = self.model.match_parent_path(
             add_qk,
@@ -570,6 +572,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
             [1, 0, 2, 1, 0, 0, 0],
         )
+        attn_mask_nodes_7 = self.model.match_parent_path(
+            add_qk,
+            ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
+            [1, 0, 0, 0, 0, 1, 0, 0, 0],
+        )
         if attn_mask_nodes_1 is not None:
             _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
             attn_mask = slice_mask_1.output[0]
@@ -588,6 +595,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         elif attn_mask_nodes_6 is not None:
             # The mask has already been reshaped to (B,N,S,T)
             add_qk_str = attn_mask_nodes_6[0].output[0]
+        elif attn_mask_nodes_7 is not None:
+            # Reshape from (B,1,S,T) to (B,N,S,T)
+            add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
         else:
             logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
             return
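For readers unfamiliar with how these patterns work: `match_parent_path` walks backwards from a node through its producers, checking an expected op type and input index at each step. The helper below is a minimal illustrative sketch of that idea, not the `OnnxModel` implementation in ONNX Runtime; the function name and the `output_name_to_node` mapping are assumptions made here for illustration.

```python
from typing import Dict, List, Optional

from onnx import NodeProto


def match_parent_path(
    node: NodeProto,
    parent_op_types: List[str],
    parent_input_indices: List[int],
    output_name_to_node: Dict[str, NodeProto],
) -> Optional[List[NodeProto]]:
    """Sketch: follow node.input[idx] upward, requiring each producer to match the given op type."""
    matched = []
    current = node
    for op_type, idx in zip(parent_op_types, parent_input_indices):
        if idx >= len(current.input):
            return None
        # Find the node that produces this input tensor.
        parent = output_name_to_node.get(current.input[idx])
        if parent is None or parent.op_type != op_type:
            return None
        matched.append(parent)
        current = parent
    return matched
```

Read this way, `attn_mask_nodes_7` simply matches the longer Where/Cast chain that newer `transformers` exports emit for the attention mask, ending at the same `Unsqueeze` nodes as the older patterns.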

onnxruntime/python/tools/transformers/models/llama/README.md

+7 −24
@@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
 
 To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
 
-As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:
-
-```
-# Before
-if self.use_cache:
-    if past_key_values is not None:
-        input_ids = input_ids[:, -1:]
-# Flatten the past_key_values (no need to flatten for models using multi-query attn)
-
-
-# After
-if self.use_cache:
-    if past_key_values is not None:
-        input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
-# Flatten the past_key_values (no need to flatten for models using multi-query attn)
-```
-
 ### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
 
 Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
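With the source-install workaround removed, a released Optimum can load the exported model directly. The snippet below is a hedged usage sketch, not part of this PR: the local directory name is a placeholder, and it assumes the ONNX model plus the `config.json` / `generation_config.json` files described above live in that directory.

```python
# Hedged sketch: load an exported LLaMA-2 ONNX model with Optimum's ORT wrapper.
# "./Llama-2-7b-hf-onnx" is a placeholder directory; depending on how the model was
# exported, a file_name=... argument may be needed to point at the right .onnx file.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_dir = "./Llama-2-7b-hf-onnx"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = ORTModelForCausalLM.from_pretrained(model_dir)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```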
@@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.
 
 1. PyTorch without `torch.compile`, FP32
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-eager \
     --model-name meta-llama/Llama-2-7b-hf \
     --precision fp32 \
@@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \
 
 2. PyTorch with `torch.compile`, FP16
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-compile \
     --model-name meta-llama/Llama-2-7b-hf \
     --precision fp16 \
@@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \
 
 3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \
 
 4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \
 
 5. ONNX Runtime, FP32, Microsoft custom export
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \
 
 6. ONNX Runtime, FP16, Microsoft custom export
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
 ### Benchmark All
 You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
 ```
-python3 -m models.llama.benchmark_all \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
     --hf-pt-eager \
     --hf-pt-compile \
     --hf-ort-dir-path ./llama2-7b-fp16/ \

onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py

+27
@@ -4,6 +4,8 @@
 import logging
 import os
 import shutil
+import subprocess
+import sys
 from itertools import chain
 
 import onnx
@@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
         only_onnxruntime=False,
     )
     model_opt.save_model_to_file(output_path, use_external_data_format=True)
+
+    # Run symbolic shape inference on optimized model to avoid shape errors during runtime
+    # Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
+    # After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
+    wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
+    source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
+    symbolic_shape_infer_args = [
+        "--input",
+        output_path,
+        "--output",
+        output_path,
+        "--auto_merge",
+        "--save_as_external_data",
+        "--all_tensors_to_one_file",
+        "--external_data_location",
+        os.path.basename(output_path) + ".data",
+    ]
+
+    file_path = os.path.dirname(__file__)
+    if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
+        main_cmd = wheel_cmd
+    else:
+        main_cmd = source_cmd
+    subprocess.run(main_cmd + symbolic_shape_infer_args)  # noqa: PLW1510
+
 logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
 if remove_model:
     remove_existing_model(input_path)
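The diff above shells out to the packaged `symbolic_shape_infer` tool with the listed flags. Purely for illustration, the same step can typically be done in-process with the `SymbolicShapeInference` class; this is a hedged sketch under the assumption that the model and its external data fit in memory, not what `convert_to_onnx.py` does (the PR keeps the subprocess so the CLI handles external data).

```python
# Hedged sketch: in-process symbolic shape inference on an already-optimized model.
# "model.onnx" is a placeholder path; the PR itself invokes the tool via subprocess.
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model.onnx")  # also pulls in model.onnx.data if external data is used
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(
    inferred,
    "model.onnx",
    save_as_external_data=True,      # keep weights outside the protobuf, like the CLI flags above
    all_tensors_to_one_file=True,
    location="model.onnx.data",
)
```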

onnxruntime/python/tools/transformers/models/llama/llama_torch.py

+1
@@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
     if i == rank % (world_size):
         l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
         l_config.use_cache = True
+        l_config._attn_implementation = "eager"  # "eager" uses LlamaAttention for attention layer
         llama = AutoModelForCausalLM.from_pretrained(
             location,
             use_auth_token=use_auth_token,
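The single added line is what forces the export path onto the eager `LlamaAttention` implementation. As a standalone illustration (not the exact code in `llama_torch.py`, and with a placeholder model id), loading a model this way might look like the sketch below.

```python
# Hedged sketch: force the eager LlamaAttention implementation before ONNX export,
# mirroring the l_config._attn_implementation = "eager" line added above.
# "meta-llama/Llama-2-7b-hf" is used as a placeholder model id.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_id)
config.use_cache = True
config._attn_implementation = "eager"  # selects LlamaAttention instead of SDPA/FlashAttention2

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.float32,
)
model.eval()  # export should see the attention pattern that the fusion rules expect
```

On newer `transformers` releases, passing the public `attn_implementation="eager"` argument to `from_pretrained` achieves the same effect as setting the private config attribute.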

onnxruntime/python/tools/transformers/models/llama/requirements.txt

+1 −1
@@ -1,4 +1,4 @@
-git+https://github.com/huggingface/optimum.git
+optimum>=1.14.1
 transformers>=4.33.2
 torch>=2.2.0.dev20230920
 onnx>=1.14.0
