Skip to content

Commit 3ea9099

Browse files
committed
Add varlen MHA fp16 slen=384 kernels for sm_86
1. add varlen mha fp16 slen=384 kernel for sm_86 2. referesh all sm_86 kernels now use NVCC -gencode=arch=compute_86,code=\"sm_86\" 3. use unfused kernel for fixed len s=384 fp16 Signed-off-by: Rajeev Rao <[email protected]>
1 parent 0fa021a commit 3ea9099

19 files changed

+53684
-88
lines changed

Diff for: plugin/bertQKVToContextPlugin/fused_multihead_attention.h

+5
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ extern unsigned char fused_multihead_attention_int8_384_64_kernel_sm80_cu_o[];
9898
extern unsigned char fused_multihead_attention_int8_128_64_kernel_sm80_cu_o[];
9999
extern unsigned char fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o[];
100100
extern unsigned char fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o[];
101+
extern unsigned char fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o[];
101102

102103
extern unsigned int fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o_len;
103104
extern unsigned int fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o_len;
@@ -111,6 +112,7 @@ extern unsigned int fused_multihead_attention_int8_384_64_kernel_sm80_cu_o_len;
111112
extern unsigned int fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len;
112113
extern unsigned int fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len;
113114
extern unsigned int fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o_len;
115+
extern unsigned int fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o_len;
114116

115117
static const struct FusedMultiHeadAttentionKernelMetaInfoV1
116118
{
@@ -175,6 +177,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV1
175177
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o,
176178
fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len, "fused_multihead_attention_fp16_128_64_kernel_sm80",
177179
49152, 128},
180+
{DATA_TYPE_FP16, 384, 64, kSM_86, fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o,
181+
fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o_len, "fused_multihead_attention_fp16_384_64_kernel_sm80",
182+
65536, 256},
178183
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_int8_128_64_kernel_sm80_cu_o,
179184
fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len, "fused_multihead_attention_int8_128_64_kernel_sm80",
180185
24576, 128},

Diff for: plugin/bertQKVToContextPlugin/fused_multihead_attention_fp16_384_64_kernel.sm86.cpp

+1,661
Large diffs are not rendered by default.

Diff for: plugin/bertQKVToContextPlugin/fused_multihead_attention_v2.h

+69-45
Original file line numberDiff line numberDiff line change
@@ -111,49 +111,67 @@ struct Fused_multihead_attention_params_v2
111111
////////////////////////////////////////////////////////////////////////////////////////////////////
112112
extern unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin[];
113113
extern unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin[];
114+
extern unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin[];
114115
extern unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin[];
115116
extern unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin[];
117+
extern unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin[];
116118
extern unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin[];
117119
extern unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin[];
120+
extern unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin[];
118121
extern unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin[];
119122
extern unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin[];
123+
extern unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin[];
120124
extern unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin[];
121125
extern unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin[];
126+
extern unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin[];
122127
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_cubin[];
123128
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin[];
124129
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin[];
130+
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin[];
125131
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_cubin[];
126132
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin[];
127133
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin[];
134+
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin[];
128135
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_cubin[];
129136
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin[];
130137
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin[];
138+
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin[];
131139
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_cubin[];
132140
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin[];
133141
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin[];
142+
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin[];
134143

135144
extern unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len;
136145
extern unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len;
146+
extern unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len;
137147
extern unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin_len;
138148
extern unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len;
149+
extern unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin_len;
139150
extern unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin_len;
140151
extern unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin_len;
152+
extern unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len;
141153
extern unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin_len;
142154
extern unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len;
155+
extern unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin_len;
143156
extern unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin_len;
144157
extern unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len;
158+
extern unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin_len;
145159
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_cubin_len;
146160
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin_len;
147161
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len;
162+
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len;
148163
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_cubin_len;
149164
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin_len;
150165
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len;
166+
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len;
151167
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_cubin_len;
152168
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin_len;
153169
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len;
170+
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len;
154171
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_cubin_len;
155172
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len;
156173
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len;
174+
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len;
157175

158176
static const struct FusedMultiHeadAttentionKernelMetaInfoV2
159177
{
@@ -348,72 +366,78 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
348366

349367
// GA10x
350368
// Note: For GA10X keep only kernels whose sharedMemBytes < 100KiB
351-
{DATA_TYPE_FP16, 64, 64, kSM_86, fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin,
352-
fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len,
369+
{DATA_TYPE_FP16, 64, 64, kSM_86, fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin,
370+
fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin_len,
353371
"fused_multihead_attention_v2_fp16_64_64_kernel_sm80", 32768, 128, 0, false},
354-
{DATA_TYPE_FP16, 96, 64, kSM_86, fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin,
355-
fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len,
372+
{DATA_TYPE_FP16, 96, 64, kSM_86, fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin,
373+
fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin_len,
356374
"fused_multihead_attention_v2_fp16_96_64_kernel_sm80", 49152, 128, 0, false},
357-
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin,
358-
fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len,
375+
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin,
376+
fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len,
359377
"fused_multihead_attention_v2_fp16_128_64_kernel_sm80_noloop", 40960, 128, 32, false},
360-
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin,
361-
fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len,
378+
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin,
379+
fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len,
362380
"fused_multihead_attention_v2_fp16_128_64_kernel_sm80", 65536, 128, 0, false},
363-
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin,
364-
fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len,
381+
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin,
382+
fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin_len,
365383
"fused_multihead_attention_v2_fp16_256_64_kernel_sm80_noloop", 73728, 128, 32, false},
366-
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin,
367-
fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len,
384+
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin,
385+
fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin_len,
368386
"fused_multihead_attention_v2_fp16_256_64_kernel_sm80", 73728, 128, 0, false},
369-
370-
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
371-
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
387+
{DATA_TYPE_FP16, 384, 64, kSM_86, fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin,
388+
fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len,
389+
"fused_multihead_attention_v2_fp16_384_64_kernel_sm80_noloop", 65536, 256, 48, false},
390+
{DATA_TYPE_FP16, 384, 64, kSM_86, fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin,
391+
fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len,
392+
"fused_multihead_attention_v2_fp16_384_64_kernel_sm80", 65536, 256, 0, false},
393+
394+
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
395+
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
372396
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_interleaved_noloop", 20480, 128, 16, true},
373-
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
374-
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
397+
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
398+
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
375399
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_noloop", 20480, 128, 16, false},
376-
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
377-
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
400+
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
401+
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
378402
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_interleaved", 24576, 128, 0, true},
379-
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
380-
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
403+
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
404+
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
381405
"fused_multihead_attention_v2_int8_128_64_kernel_sm80", 32768, 128, 0, false},
382-
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
383-
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
406+
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
407+
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
384408
"fused_multihead_attention_v2_int8_192_64_kernel_sm80_interleaved_noloop", 28672, 128, 32, true},
385-
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
386-
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
409+
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
410+
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
387411
"fused_multihead_attention_v2_int8_192_64_kernel_sm80_noloop", 28672, 128, 32, false},
388-
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
389-
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
412+
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
413+
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
390414
"fused_multihead_attention_v2_int8_192_64_kernel_sm80_interleaved", 32768, 128, 0, true},
391-
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
392-
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
415+
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
416+
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
393417
"fused_multihead_attention_v2_int8_192_64_kernel_sm80", 32768, 128, 0, false},
394-
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
395-
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
418+
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
419+
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
396420
"fused_multihead_attention_v2_int8_256_64_kernel_sm80_interleaved_noloop", 36864, 128, 32, true},
397-
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
398-
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
421+
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
422+
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
399423
"fused_multihead_attention_v2_int8_256_64_kernel_sm80_noloop", 36864, 128, 32, false},
400-
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
401-
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
424+
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
425+
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
402426
"fused_multihead_attention_v2_int8_256_64_kernel_sm80_interleaved", 36864, 128, 0, true},
403-
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
404-
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
427+
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
428+
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
405429
"fused_multihead_attention_v2_int8_256_64_kernel_sm80", 36864, 128, 0, false},
406-
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
407-
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
430+
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
431+
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
408432
"fused_multihead_attention_v2_int8_384_64_kernel_sm80_interleaved_noloop", 53248, 128, 32, true},
409-
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
410-
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
433+
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
434+
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
411435
"fused_multihead_attention_v2_int8_384_64_kernel_sm80_noloop", 53248, 128, 32, false},
412-
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
413-
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
436+
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
437+
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
414438
"fused_multihead_attention_v2_int8_384_64_kernel_sm80_interleaved", 51200, 128, 0, true},
415-
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
416-
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
439+
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
440+
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
417441
"fused_multihead_attention_v2_int8_384_64_kernel_sm80", 53248, 128, 0, false},
418442
#endif
419443
};

0 commit comments

Comments
 (0)