Skip to content

Commit 83cad68

Browse files
authored
[MLIR][NVVM] Update Float to TF32 conversion Op (#125048)
This change updates the Float to TF32 conversion MLIR Op to include lowering to the new intrinsics introduced in sm_100 through ptx8.6: - `nvvm_f2tf32_rn_satfinite` - `nvvm_f2tf32_rn_relu_satfinite` - `nvvm_f2tf32_rz_satfinite` - `nvvm_f2tf32_rz_relu_satfinite` PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent 028b690 commit 83cad68

File tree

3 files changed

+40
-26
lines changed

3 files changed

+40
-26
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,6 @@ LogicalResult CvtFloatToTF32Op::verify() {
147147
break;
148148
case RndMode::RN:
149149
case RndMode::RZ:
150-
if (getSat() != NVVM::SaturationMode::NONE)
151-
return emitError(
152-
"Saturation mode not supported with rn/rz rounding modes.");
153150
break;
154151
default:
155152
return emitError(
@@ -1221,21 +1218,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
12211218
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
12221219
}
12231220

1221+
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1222+
hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
1223+
: llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
1224+
1225+
#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
1226+
hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1227+
: CVT_F2TF32_ID_IMPL(rnd, relu, )
1228+
12241229
llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12251230
NVVM::SaturationMode sat,
12261231
bool hasRelu) {
12271232
using RndMode = NVVM::FPRoundingMode;
1233+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
12281234
switch (rnd) {
12291235
case RndMode::RN:
1230-
return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
1231-
: llvm::Intrinsic::nvvm_f2tf32_rn;
1236+
return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
12321237
case RndMode::RZ:
1233-
return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
1234-
: llvm::Intrinsic::nvvm_f2tf32_rz;
1238+
return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
12351239
case RndMode::RNA:
1236-
return (sat == NVVM::SaturationMode::SATFINITE)
1237-
? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
1238-
: llvm::Intrinsic::nvvm_f2tf32_rna;
1240+
return GET_CVT_F2TF32_ID(rna, , _satfinite);
12391241
default:
12401242
llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
12411243
}

mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
2828
llvm.return %res : i32
2929
}
3030

31+
// CHECK-LABEL: @convert_float_to_tf32_rn_sf
32+
llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
33+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.satfinite(float %{{.*}})
34+
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
35+
llvm.return %res : i32
36+
}
37+
38+
// CHECK-LABEL: @convert_float_to_tf32_rn_relu_sf
39+
llvm.func @convert_float_to_tf32_rn_relu_sf(%src : f32) -> i32 {
40+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu.satfinite(float %{{.*}})
41+
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true, sat = #nvvm.sat_mode<satfinite>}
42+
llvm.return %res : i32
43+
}
44+
3145
// CHECK-LABEL: @convert_float_to_tf32_rz
3246
llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
3347
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
@@ -41,3 +55,17 @@ llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
4155
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
4256
llvm.return %res : i32
4357
}
58+
59+
// CHECK-LABEL: @convert_float_to_tf32_rz_sf
60+
llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
61+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.satfinite(float %{{.*}})
62+
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
63+
llvm.return %res : i32
64+
}
65+
66+
// CHECK-LABEL: @convert_float_to_tf32_rz_relu_sf
67+
llvm.func @convert_float_to_tf32_rz_relu_sf(%src : f32) -> i32 {
68+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu.satfinite(float %{{.*}})
69+
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true, sat = #nvvm.sat_mode<satfinite>}
70+
llvm.return %res : i32
71+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -117,22 +117,6 @@ llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
117117

118118
// -----
119119

120-
llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
121-
// expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
122-
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
123-
llvm.return %res : i32
124-
}
125-
126-
// -----
127-
128-
llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
129-
// expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
130-
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
131-
llvm.return %res : i32
132-
}
133-
134-
// -----
135-
136120
llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
137121
// expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
138122
%res = nvvm.cvt.float.to.tf32 %src

0 commit comments

Comments
 (0)