Skip to content

Commit a8e1c6f

Browse files
authored
[MLIR][AMDGPU] Add support for fp8 ops on gfx12 (#106388)
This PR is adding support for `fp8` and `bfp8` on gfx12
1 parent 0b2f253 commit a8e1c6f

File tree

6 files changed

+59
-27
lines changed

6 files changed

+59
-27
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
552552
VectorOfLengthAndType<[4, 16, 32], [I32]>,
553553
VectorOfLengthAndType<[4], [F64]>]>;
554554
// wmma
555-
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
555+
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
556556
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
557557
VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
558558

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

+4-1
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
329329
"$args attr-dict `:` functional-type($args, $res)";
330330
}
331331

332-
// Available on RDNA3
332+
// Available from gfx11
333333
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
334334
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
335335
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
336336
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
337337
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
338338
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
339+
// Available from gfx12
340+
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
341+
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
339342

340343
//===---------------------------------------------------------------------===//
341344
// Operations on raw buffer resources (stride of 0, bounds checks either off or in

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

+31-22
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
385385
Location loc,
386386
const TypeConverter *typeConverter,
387387
bool isUnsigned, Value llvmInput,
388+
Value mlirInput,
388389
SmallVector<Value, 4> &operands) {
389390
Type inputType = llvmInput.getType();
390391
auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,29 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
398399
return;
399400
}
400401

402+
// We need to check the type of the input before conversion to properly test
403+
// for int8. This is because, in LLVM, fp8 type is converted to int8, so the
404+
// fp8/int8 information is lost during the conversion process.
405+
auto mlirInputType = cast<VectorType>(mlirInput.getType());
406+
bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
407+
if (isInputInt8) {
408+
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
409+
bool localIsUnsigned = isUnsigned;
410+
if (elemType.isUnsignedInteger(8)) {
411+
localIsUnsigned = true;
412+
} else if (elemType.isSignedInteger(8)) {
413+
localIsUnsigned = false;
414+
}
415+
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
416+
operands.push_back(sign);
417+
}
418+
401419
int64_t numBytes = vectorType.getNumElements();
402420
Type i32 = rewriter.getI32Type();
403421
VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
404422
auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
405-
406423
Value result = rewriter.createOrFold<LLVM::BitcastOp>(
407424
loc, llvmVectorType32bits, llvmInput);
408-
409-
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
410-
bool localIsUnsigned = isUnsigned;
411-
if (elemType.isUnsignedInteger(8)) {
412-
localIsUnsigned = true;
413-
} else if (elemType.isSignedInteger(8)) {
414-
localIsUnsigned = false;
415-
}
416-
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
417-
operands.push_back(sign);
418425
operands.push_back(result);
419426
}
420427

@@ -590,18 +597,20 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
590597
auto elemSourceType = sourceVectorType.getElementType();
591598
auto elemDestType = destVectorType.getElementType();
592599

593-
if (elemSourceType.isF16() && elemDestType.isF32()) {
600+
if (elemSourceType.isF16() && elemDestType.isF32())
594601
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
595-
}
596-
if (elemSourceType.isBF16() && elemDestType.isF32()) {
602+
if (elemSourceType.isBF16() && elemDestType.isF32())
597603
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
598-
} else if (elemSourceType.isF16() && elemDestType.isF16()) {
604+
if (elemSourceType.isF16() && elemDestType.isF16())
599605
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
600-
} else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
606+
if (elemSourceType.isBF16() && elemDestType.isBF16())
601607
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
602-
} else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
608+
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
603609
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
604-
}
610+
if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
611+
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
612+
if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
613+
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
605614
return std::nullopt;
606615
}
607616

@@ -662,8 +671,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
662671
Location loc = op.getLoc();
663672
Type outType = typeConverter->convertType(op.getDestD().getType());
664673

665-
if (chipset.majorVersion != 11)
666-
return op->emitOpError("WMMA only supported on gfx11");
674+
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
675+
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
667676

668677
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
669678

@@ -675,9 +684,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
675684

676685
SmallVector<Value, 4> operands;
677686
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
678-
adaptor.getSourceA(), operands);
687+
adaptor.getSourceA(), op.getSourceA(), operands);
679688
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
680-
adaptor.getSourceB(), operands);
689+
adaptor.getSourceB(), op.getSourceB(), operands);
681690
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
682691
op.getSubwordOffset(), op.getClamp(), operands);
683692

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,10 @@ LogicalResult WMMAOp::verify() {
234234
Type sourceAElemType = sourceVectorAType.getElementType();
235235
Type destElemType = destVectorType.getElementType();
236236

237-
bool isDestFloat =
238-
(destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
239-
bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
237+
bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
238+
bool isSrcFloat =
239+
isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
240+
sourceAElemType);
240241

241242
if (isDestFloat && !isSrcFloat) {
242243
return emitOpError("Expected float sources with float destination");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
2+
func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
3+
// CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
4+
amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
5+
6+
// CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
7+
amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
8+
func.return
9+
}

mlir/test/Target/LLVMIR/rocdl.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
377377
llvm.return %rsrc : !llvm.ptr<8>
378378
}
379379

380+
llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
381+
// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
382+
%r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
383+
384+
// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
385+
%r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
386+
387+
llvm.return %r0 : vector<8 x f32>
388+
}
389+
380390
llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
381391
%offset : i32, %soffset : i32,
382392
%vdata1 : i32,

0 commit comments

Comments
 (0)