@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
385
385
Location loc,
386
386
const TypeConverter *typeConverter,
387
387
bool isUnsigned, Value llvmInput,
388
+ Value mlirInput,
388
389
SmallVector<Value, 4 > &operands) {
389
390
Type inputType = llvmInput.getType ();
390
391
auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,29 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
398
399
return ;
399
400
}
400
401
402
+ // We need to check the type of the input before conversion to properly test
403
+ // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
404
+ // fp8/int8 information is lost during the conversion process.
405
+ auto mlirInputType = cast<VectorType>(mlirInput.getType ());
406
+ bool isInputInt8 = mlirInputType.getElementType ().isInteger (8 );
407
+ if (isInputInt8) {
408
+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
409
+ bool localIsUnsigned = isUnsigned;
410
+ if (elemType.isUnsignedInteger (8 )) {
411
+ localIsUnsigned = true ;
412
+ } else if (elemType.isSignedInteger (8 )) {
413
+ localIsUnsigned = false ;
414
+ }
415
+ Value sign = createI1Constant (rewriter, loc, !localIsUnsigned);
416
+ operands.push_back (sign);
417
+ }
418
+
401
419
int64_t numBytes = vectorType.getNumElements ();
402
420
Type i32 = rewriter.getI32Type ();
403
421
VectorType vectorType32bits = VectorType::get (numBytes * 8 / 32 , i32);
404
422
auto llvmVectorType32bits = typeConverter->convertType (vectorType32bits);
405
-
406
423
Value result = rewriter.createOrFold <LLVM::BitcastOp>(
407
424
loc, llvmVectorType32bits, llvmInput);
408
-
409
- // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
410
- bool localIsUnsigned = isUnsigned;
411
- if (elemType.isUnsignedInteger (8 )) {
412
- localIsUnsigned = true ;
413
- } else if (elemType.isSignedInteger (8 )) {
414
- localIsUnsigned = false ;
415
- }
416
- Value sign = createI1Constant (rewriter, loc, !localIsUnsigned);
417
- operands.push_back (sign);
418
425
operands.push_back (result);
419
426
}
420
427
@@ -590,18 +597,20 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
590
597
auto elemSourceType = sourceVectorType.getElementType ();
591
598
auto elemDestType = destVectorType.getElementType ();
592
599
593
- if (elemSourceType.isF16 () && elemDestType.isF32 ()) {
600
+ if (elemSourceType.isF16 () && elemDestType.isF32 ())
594
601
return ROCDL::wmma_f32_16x16x16_f16::getOperationName ();
595
- }
596
- if (elemSourceType.isBF16 () && elemDestType.isF32 ()) {
602
+ if (elemSourceType.isBF16 () && elemDestType.isF32 ())
597
603
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName ();
598
- } else if (elemSourceType.isF16 () && elemDestType.isF16 ()) {
604
+ if (elemSourceType.isF16 () && elemDestType.isF16 ())
599
605
return ROCDL::wmma_f16_16x16x16_f16::getOperationName ();
600
- } else if (elemSourceType.isBF16 () && elemDestType.isBF16 ()) {
606
+ if (elemSourceType.isBF16 () && elemDestType.isBF16 ())
601
607
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName ();
602
- } else if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 )) {
608
+ if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
603
609
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
604
- }
610
+ if (elemSourceType.isFloat8E4M3FN () && elemDestType.isF32 ())
611
+ return ROCDL::wmma_f32_16x16x16_fp8::getOperationName ();
612
+ if (elemSourceType.isFloat8E5M2 () && elemDestType.isF32 ())
613
+ return ROCDL::wmma_f32_16x16x16_bf8::getOperationName ();
605
614
return std::nullopt;
606
615
}
607
616
@@ -662,8 +671,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
662
671
Location loc = op.getLoc ();
663
672
Type outType = typeConverter->convertType (op.getDestD ().getType ());
664
673
665
- if (chipset.majorVersion != 11 )
666
- return op->emitOpError (" WMMA only supported on gfx11" );
674
+ if (chipset.majorVersion != 11 && chipset. majorVersion != 12 )
675
+ return op->emitOpError (" WMMA only supported on gfx11 and gfx12 " );
667
676
668
677
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic (op, chipset);
669
678
@@ -675,9 +684,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
675
684
676
685
SmallVector<Value, 4 > operands;
677
686
wmmaPushInputOperand (rewriter, loc, typeConverter, op.getUnsignedA (),
678
- adaptor.getSourceA (), operands);
687
+ adaptor.getSourceA (), op. getSourceA (), operands);
679
688
wmmaPushInputOperand (rewriter, loc, typeConverter, op.getUnsignedB (),
680
- adaptor.getSourceB (), operands);
689
+ adaptor.getSourceB (), op. getSourceB (), operands);
681
690
wmmaPushOutputOperand (rewriter, loc, typeConverter, adaptor.getDestC (),
682
691
op.getSubwordOffset (), op.getClamp (), operands);
683
692
0 commit comments