### Description
This PR updates the LLaMA-2 attention fusions with the following changes.
- Loading the PyTorch model from Hugging Face with the `LlamaAttention`
class before exporting
- Updating the attention mask pattern matching to support another case
This PR also fixes [this
issue](#19040).
### Motivation and Context
Recent changes to Hugging Face's `transformers` library break the
existing pattern matching. The attention fusions aim to change the
graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to
`LayerNorm Op --> Attention Op --> LayerNorm Op` for each layer, so it
ultimately does not matter which nodes make up the `Set of Attention Nodes`:
they are all removed and replaced by the single `Attention Op` in the end.
Consequently, the attention class used to load the PyTorch model before
exporting does not matter either, because the expected graph after the
attention fusions looks identical regardless of the class chosen. Loading the
PyTorch model with the `LlamaAttention` class instead of another attention
class (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then
exporting it to ONNX allows the existing pattern matching to keep working.
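As a minimal sketch, loading the model with the eager `LlamaAttention` implementation before export could look like the following (the model name and the use of the `attn_implementation` keyword are illustrative assumptions, not code taken from this PR):

```
import torch
from transformers import AutoModelForCausalLM

# Sketch only: force the eager LlamaAttention implementation so the exported
# ONNX graph contains the attention subgraph that the fusion patterns expect.
# The model name below is illustrative.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float32,
    attn_implementation="eager",  # LlamaAttention instead of LlamaSdpaAttention / LlamaFlashAttention2
)
model.eval()
```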
To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
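For reference, a minimal sketch of loading such a directory with Optimum might look like this (the directory path and model name are illustrative assumptions):

```
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Sketch only: the directory is assumed to contain the exported ONNX model
# together with config.json and generation_config.json copied from the
# original Hugging Face repository.
model_dir = "./llama2-7b-onnx"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = ORTModelForCausalLM.from_pretrained(model_dir, use_cache=True)
```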
As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:

```
# Before
if self.use_cache:
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
        # Flatten the past_key_values (no need to flatten for models using multi-query attn)

# After
if self.use_cache:
    if past_key_values is not None:
        input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
        # Flatten the past_key_values (no need to flatten for models using multi-query attn)
```
### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
Here are some examples of how you can benchmark LLaMA-2.