[BERT] Add support for sdpa #28802


Merged on Apr 26, 2024
22 commits
ed6efd0
Adding SDPA support for BERT
hackyon Feb 14, 2024
0f34d58
Using the proper input name for testing model input in inference()
hackyon Feb 15, 2024
c85acbd
Adding documentation for SDPA in BERT model page
hackyon Feb 15, 2024
5ce07b3
Use the stable link for the documentation
hackyon Feb 15, 2024
288cc1d
Adding a gate to only call .contiguous() for torch < 2.2.0
hackyon Feb 16, 2024
2afd61f
Additions and fixes to the documentation
hackyon Feb 16, 2024
fa8b5ad
Minor updates to documentation
hackyon Feb 17, 2024
05d5c4e
Adding extra requirements needed for the contiguous() bug
hackyon Feb 18, 2024
95ec569
Adding "Adapted from" in plcae of the "Copied from"
hackyon Feb 19, 2024
a07fd89
Add benchmark speedup tables to the documentation
hackyon Feb 21, 2024
64334c1
Minor fixes to the documentation
hackyon Feb 21, 2024
6a7376d
Use ClapText as a replacement for Bert in the Copied-From
hackyon Feb 21, 2024
5ddb6e1
Some more fixes for the fix-copies references
hackyon Feb 21, 2024
35577eb
Overriding the test_eager_matches_sdpa_generate in bert tests to not …
hackyon Feb 22, 2024
0e62fe0
Undo changes to separate test
hackyon Feb 23, 2024
5c64480
Refactored SDPA self attention code for KV projections
hackyon Mar 6, 2024
9a9bb9b
Change use_sdpa to attn_implementation
hackyon Mar 7, 2024
1a0af20
Merge remote-tracking branch 'upstream/main' into sdpa-bert
hackyon Mar 19, 2024
0965399
Merge remote-tracking branch 'upstream/main' into sdpa-bert
hackyon Apr 6, 2024
b4813a0
Merge remote-tracking branch 'upstream/main' into sdpa-bert
hackyon Apr 11, 2024
e312cd1
Merge remote-tracking branch 'upstream/main' into sdpa-bert
hackyon Apr 22, 2024
66a24c1
Fix test_sdpa_can_dispatch_on_flash by preparing input (required for …
hackyon Apr 22, 2024
47 changes: 47 additions & 0 deletions docs/source/en/model_doc/bert.md

Collaborator: very nice, thanks for the detailed benchmark! 🤗

@@ -61,6 +61,53 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o

- The model must predict the original sentence, but has a second objective: inputs are two sentences A and B (with a separation token in between). With probability 50%, the sentences are consecutive in the corpus; in the remaining 50%, they are not related. The model has to predict whether the sentences are consecutive or not.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to request it explicitly:

```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
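
As a fuller end-to-end sketch (assuming a CUDA GPU; the checkpoint and example sentence are placeholders, and the dtype falls back to `float32` on CPU):

```python
import torch
from transformers import AutoTokenizer, BertModel

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained(
    "bert-base-uncased", torch_dtype=dtype, attn_implementation="sdpa"
).to(device)

inputs = tokenizer("Paris is the capital of France.", return_tensors="pt").to(device)

with torch.inference_mode():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # [batch_size, seq_len, hidden_size]
```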

On a local benchmark (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04) with `float16`, we saw the
following speedups during training and inference.

#### Training

|batch_size|seq_len|Time per batch (eager - s)|Time per batch (sdpa - s)|Speedup (%)|Eager peak mem (MB)|sdpa peak mem (MB)|Mem saving (%)|
|----------|-------|--------------------------|-------------------------|-----------|-------------------|------------------|--------------|
|4 |256 |0.023 |0.017 |35.472 |939.213 |764.834 |22.800 |
|4 |512 |0.023 |0.018 |23.687 |1970.447 |1227.162 |60.569 |
|8 |256 |0.023 |0.018 |23.491 |1594.295 |1226.114 |30.028 |
|8 |512 |0.035 |0.025 |43.058 |3629.401 |2134.262 |70.054 |
|16 |256 |0.030 |0.024 |25.583 |2874.426 |2134.262 |34.680 |
|16 |512 |0.064 |0.044 |46.223 |6964.659 |3961.013 |75.830 |

#### Inference

|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%)|Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
|----------|-------|----------------------------|---------------------------|-----------|--------------|-----------|-------------|
|1 |128 |5.736 |4.987 |15.022 |282.661 |282.924 |-0.093 |
|1 |256 |5.689 |4.945 |15.055 |298.686 |298.948 |-0.088 |
|2 |128 |6.154 |4.982 |23.521 |314.523 |314.785 |-0.083 |
|2 |256 |6.201 |4.949 |25.303 |347.546 |347.033 |0.148 |
|4 |128 |6.049 |4.987 |21.305 |378.895 |379.301 |-0.107 |
|4 |256 |6.285 |5.364 |17.166 |443.209 |444.382 |-0.264 |
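
The numbers above come from a local benchmark. As a rough, hypothetical way to compare eager and SDPA forward latency on your own hardware (this is not the script used to produce the tables above), you could time a repeated forward pass:

```python
import time

import torch
from transformers import AutoTokenizer, BertModel

def mean_forward_time(attn_implementation, batch_size=4, seq_len=256, steps=20):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained(
        "bert-base-uncased", torch_dtype=dtype, attn_implementation=attn_implementation
    ).to(device)

    text = ["hello world"] * batch_size
    inputs = tokenizer(
        text, return_tensors="pt", padding="max_length", truncation=True, max_length=seq_len
    ).to(device)

    with torch.inference_mode():
        for _ in range(3):  # warmup
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps

print("eager:", mean_forward_time("eager"))
print("sdpa: ", mean_forward_time("sdpa"))
```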



## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
12 changes: 9 additions & 3 deletions docs/source/en/perf_infer_gpu_one.md
@@ -180,10 +180,11 @@ FlashAttention is more memory efficient, meaning you can train on much larger se

## PyTorch scaled dot product attention

PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available.
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to request SDPA explicitly.

For now, Transformers supports SDPA inference and training for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
@@ -210,6 +211,13 @@ FlashAttention can only be used for models with the `fp16` or `bf16` torch type,

</Tip>

<Tip>

SDPA does not support certain attention arguments, such as `head_mask` and `output_attentions=True`.
When one of these is passed, you will see a warning and the model falls back to the (slower) eager implementation.

</Tip>
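
For example, requesting attention weights triggers the fallback (a hedged sketch; the exact warning text may differ):

```python
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")

inputs = tokenizer("SDPA falls back to eager here.", return_tensors="pt")

# output_attentions=True cannot be served by the SDPA kernel, so a warning is
# logged and this forward pass runs the eager attention implementation instead.
with torch.inference_mode():
    outputs = model(**inputs, output_attentions=True)

print(len(outputs.attentions))  # one attention tensor per layer
```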

By default, SDPA selects the most performant kernel available, but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

```diff
@@ -218,8 +226,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
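
The example above is truncated in this view. Here is a self-contained sketch of the pattern it describes, forcing the FlashAttention backend during generation (`torch.backends.cuda.sdp_kernel` is assumed available on recent 2.x releases; newer ones expose `torch.nn.attention.sdpa_kernel` instead):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="sdpa"
).to("cuda")

inputs = tokenizer("Hello my dog is cute and", return_tensors="pt").to("cuda")

# Restrict SDPA to the FlashAttention kernel only; the math and memory-efficient
# backends are disabled inside this context manager.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```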
6 changes: 2 additions & 4 deletions src/transformers/modeling_attn_mask_utils.py
@@ -445,10 +445,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

if torch.all(mask == 1):
if is_tracing:
pass
elif tgt_len == 1:
if not is_tracing and torch.all(mask == 1):
if tgt_len == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
return None
elif key_value_length == tgt_len:
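
For intuition, a standalone, simplified sketch of the check above (not the library's actual helper; the mask-expansion branch is only stubbed out):

```python
import torch

def simplified_sdpa_mask(mask, tgt_len, key_value_length, is_tracing):
    # When we are not tracing and every key position is attended to, returning
    # None lets torch.nn.functional.scaled_dot_product_attention run unmasked
    # and pick its fastest kernel.
    if not is_tracing and torch.all(mask == 1):
        if tgt_len == 1:
            # For a single query token, causal and bi-directional attention coincide.
            return None
        elif key_value_length == tgt_len:
            # Fully visible square mask: no mask is needed at all.
            return None
    # Otherwise the 2D padding mask would be expanded into a 4D additive mask
    # (omitted in this sketch).
    ...
```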
11 changes: 9 additions & 2 deletions src/transformers/models/align/modeling_align.py
@@ -883,11 +883,18 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText
ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
"eager": AlignTextSelfAttention,
}


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
class AlignTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = AlignTextSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = AlignTextSelfOutput(config)
self.pruned_heads = set()

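
The same dispatch dictionary keyed on `config._attn_implementation` is what this PR adds to BERT itself. A hedged sketch of the BERT-side equivalent (`BertSdpaSelfAttention` is the class name assumed to be introduced by this PR):

```python
from torch import nn
from transformers.models.bert.modeling_bert import (
    BertSdpaSelfAttention,  # assumed to exist after this PR
    BertSelfAttention,
    BertSelfOutput,
)

BERT_SELF_ATTENTION_CLASSES = {
    "eager": BertSelfAttention,
    "sdpa": BertSdpaSelfAttention,
}


class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # config._attn_implementation resolves to "sdpa" by default on torch>=2.1.1,
        # or to whatever attn_implementation was requested in from_pretrained().
        self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
```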
15 changes: 11 additions & 4 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -434,11 +434,18 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
return hidden_states


# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta
ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
"eager": AltRobertaSelfAttention,
}


# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
class AltRobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = AltRobertaSelfOutput(config)
self.pruned_heads = set()

@@ -1205,7 +1212,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):

config_class = AltCLIPTextConfig

# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->AltRoberta
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->AltRoberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@@ -1232,7 +1239,7 @@ class PreTrainedModel
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,