Skip to content

Fix cuda kernel for seq_len < 8192 #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 );
if (softmax_elements == 0) {
return;
} else {
Expand Down Expand Up @@ -416,6 +416,14 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
Expand All @@ -432,7 +440,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 );
if (softmax_elements == 0) {
return;
} else {
Expand Down Expand Up @@ -507,6 +515,14 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ torch::Tensor fwd_cuda(
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
TORCH_INTERNAL_ASSERT(seq_len <= 8192);

// Output
auto act_options = input.options().requires_grad(false);
Expand Down
4 changes: 4 additions & 0 deletions megatron/fused_kernels/tests/test_fused_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import torch
from torch.nn import LayerNorm

import sys
# add to path
sys.path.append("/home/nouamane/projects/Megatron-DeepSpeed/")
import megatron
from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
Expand Down
2 changes: 1 addition & 1 deletion megatron/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _compile_dependencies():
args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
custom_kernel_constraint = seq_len > 16 and seq_len <=8192 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
# Print a warning.
if not ((args.fp16 or args.bf16) and
Expand Down
4 changes: 2 additions & 2 deletions megatron/model/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,11 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 4096
and 16 < sk <= 8192 # sk must be 16 ~ 8192
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
if 0 <= sk <= 8192:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

if self.attn_mask_type == AttnMaskType.causal:
Expand Down