Skip to content

Commit 10b56e0

Browse files
committed
[mlir][Arith] Add pass for emulating unsupported float ops (#1079)
To complement the bf16 expansion and truncation patterns added to ExpandOps, define a pass that replaces, for any arithmetic operation op, `%y = arith.op %v0, %v1, ... : T` with `%e0 = arith.extf %v0 : T to U`, `%e1 = arith.extf %v1 : T to U`, ..., `%y.exp = arith.op %e0, %e1, ... : U`, `%y = arith.truncf %y.exp : U to T`. This allows for "emulating" floating-point operations not supported on a given target (such as bfloat operations or most arithmetic on 8-bit floats) by extending those types to supported ones, performing the arithmetic operation, and then truncating back to the original type (which ensures appropriate rounding behavior). The lowering of the extf and truncf ops introduced by this transformation should be handled by subsequent passes. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D154539
1 parent 980cd18 commit 10b56e0

File tree

5 files changed

+311
-0
lines changed

5 files changed

+311
-0
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
namespace mlir {
1515
class DataFlowSolver;
16+
class ConversionTarget;
17+
class TypeConverter;
1618

1719
namespace arith {
1820

@@ -42,6 +44,21 @@ void populateArithWideIntEmulationPatterns(
4244
void populateArithNarrowTypeEmulationPatterns(
4345
NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
4446

47+
/// Populate the type conversions needed to emulate the unsupported
48+
/// `sourceTypes` with `targetType`.
49+
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter,
50+
ArrayRef<Type> sourceTypes,
51+
Type targetType);
52+
53+
/// Add rewrite patterns for converting operations that use illegal float types
54+
/// to ones that use legal ones.
55+
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns,
56+
TypeConverter &converter);
57+
58+
/// Set up a dialect conversion to reject arithmetic operations on unsupported
59+
/// float types.
60+
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target,
61+
TypeConverter &converter);
4562
/// Add patterns to expand Arith ceil/floor division ops.
4663
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
4764

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,28 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
6363
}];
6464
}
6565

66+
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
67+
let summary = "Emulate operations on unsupported floats with extf/truncf";
68+
let description = [{
69+
Emulate arith and vector floating point operations that use float types
70+
which are unsupported on a target by inserting extf/truncf pairs around all
71+
such operations in order to produce arithmetic that can be performed while
72+
preserving the original rounding behavior.
73+
74+
This pass does not attempt to reason about the operations being performed
75+
to determine when type conversions can be elided.
76+
}];
77+
78+
let options = [
79+
ListOption<"sourceTypeStrs", "source-types", "std::string",
80+
"MLIR types without arithmetic support on a given target">,
81+
Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
82+
"MLIR type to convert the unsupported source types to">,
83+
];
84+
85+
let dependentDialects = ["vector::VectorDialect"];
86+
}
87+
6688
def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
6789
let summary = "Emulate 2*N-bit integer operations using N-bit operations";
6890
let description = [{

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRArithTransforms
22
BufferizableOpInterfaceImpl.cpp
33
Bufferize.cpp
4+
EmulateUnsupportedFloats.cpp
45
EmulateWideInt.cpp
56
EmulateNarrowType.cpp
67
ExpandOps.cpp
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This pass promotes small floats (of some unsupported types T) to a supported
9+
// type U by wrapping all float operations on Ts with expansion to and
10+
// truncation from U, then operating on U.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
17+
#include "mlir/IR/BuiltinTypes.h"
18+
#include "mlir/IR/Location.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
#include "llvm/ADT/STLExtras.h"
22+
#include <optional>
23+
24+
namespace mlir::arith {
25+
#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
26+
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
27+
} // namespace mlir::arith
28+
29+
using namespace mlir;
30+
31+
namespace {
32+
struct EmulateUnsupportedFloatsPass
33+
: arith::impl::ArithEmulateUnsupportedFloatsBase<
34+
EmulateUnsupportedFloatsPass> {
35+
using arith::impl::ArithEmulateUnsupportedFloatsBase<
36+
EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
37+
38+
void runOnOperation() override;
39+
};
40+
41+
struct EmulateFloatPattern final : ConversionPattern {
42+
EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
43+
: ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
44+
45+
LogicalResult match(Operation *op) const override;
46+
void rewrite(Operation *op, ArrayRef<Value> operands,
47+
ConversionPatternRewriter &rewriter) const override;
48+
};
49+
} // end namespace
50+
51+
/// Map strings to float types. This function is here because no one else needs
52+
/// it yet, feel free to abstract it out.
53+
static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
54+
StringRef name) {
55+
Builder b(ctx);
56+
return llvm::StringSwitch<std::optional<FloatType>>(name)
57+
.Case("f8E5M2", b.getFloat8E5M2Type())
58+
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
59+
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
60+
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
61+
.Case("bf16", b.getBF16Type())
62+
.Case("f16", b.getF16Type())
63+
.Case("f32", b.getF32Type())
64+
.Case("f64", b.getF64Type())
65+
.Case("f80", b.getF80Type())
66+
.Case("f128", b.getF128Type())
67+
.Default(std::nullopt);
68+
}
69+
70+
/// Accept exactly those operations that the type converter considers illegal
/// and that carry no regions.
LogicalResult EmulateFloatPattern::match(Operation *op) const {
  const bool needsEmulation = !getTypeConverter()->isLegal(op);
  // The rewrite re-creates the op with an empty region list, so operations
  // that own regions cannot be handled and are rejected here.
  const bool hasNoRegions = op->getNumRegions() == 0;
  return success(needsEmulation && hasNoRegions);
}
78+
79+
/// Re-create `op` so that it computes on the converted (wider) types, then
/// truncate every result whose type changed back to its original type.
///
/// \param op        the matched operation using an unsupported float type.
/// \param operands  the replacement operands handed in by the dialect
///                  conversion driver; they are used as-is for the new op.
/// \param rewriter  rewriter used to build the new ops and replace `op`.
void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();

  // The conversion must run unconditionally. Performing the call inside
  // assert() would drop it in NDEBUG builds, leaving `resultTypes` empty and
  // producing an op with no results.
  SmallVector<Type> resultTypes;
  LogicalResult conversionStatus =
      getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes);
  (void)conversionStatus;
  assert(succeeded(conversionStatus) &&
         "type conversions shouldn't fail in this pass");

  // Same operation name, attributes and successors; only the operand and
  // result types differ. Regions are not cloned (match() rejects such ops).
  Operation *expandedOp =
      rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
                      op->getAttrs(), op->getSuccessors(), /*regions=*/{});

  // Results whose type was widened are truncated back so that users of the
  // original op keep seeing the original type.
  SmallVector<Value> newResults(expandedOp->getResults());
  for (auto [res, oldType, newType] : llvm::zip_equal(
           MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
    if (oldType != newType)
      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
  }
  rewriter.replaceOp(op, newResults);
}
98+
99+
/// Teach `converter` to map every type in `sourceTypes`, as well as shaped
/// types whose element type is in `sourceTypes`, to `targetType`, and to
/// materialize converted values with `arith.extf`.
void mlir::arith::populateEmulateUnsupportedFloatsConversions(
    TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
  // The callback keeps its own copy of the unsupported types because
  // `sourceTypes` is only a non-owning view.
  auto convertType = [unsupported = SmallVector<Type>(sourceTypes),
                      targetType](Type type) -> std::optional<Type> {
    // Scalar of an unsupported type.
    if (llvm::is_contained(unsupported, type))
      return targetType;

    // Shaped type with an unsupported element type: keep the shape and swap
    // the element type.
    auto shaped = type.dyn_cast<ShapedType>();
    if (shaped && llvm::is_contained(unsupported, shaped.getElementType()))
      return shaped.clone(targetType);

    // Everything else is already legal and maps to itself.
    return type;
  };
  converter.addConversion(convertType);

  // Values that have to be turned into the target type are extended.
  converter.addTargetMaterialization(
      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
        return b.create<arith::ExtFOp>(loc, target, input);
      });
}
116+
117+
/// Add the float-emulation conversion pattern, driven by `converter`, to
/// `patterns`.
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
    RewritePatternSet &patterns, TypeConverter &converter) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<EmulateFloatPattern>(converter, ctx);
}
121+
122+
/// Configure `target` so that arith operations and the arithmetic-performing
/// vector operations listed below are legal only when `converter` accepts
/// them; every other operation is treated as legal.
///
/// The registered callbacks hold a reference to `converter`, so it has to
/// stay alive for as long as `target` is used.
void mlir::arith::populateEmulateUnsupportedFloatsLegality(
    ConversionTarget &target, TypeConverter &converter) {
  // Shared predicate: an op is legal iff the converter has nothing to change.
  auto hasOnlyLegalTypes = [&converter](Operation *op) {
    return converter.isLegal(op);
  };

  // Functions and other ops that need no expansion are left alone.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  // Anything in the arith dialect is checked against the converter.
  target.addDynamicallyLegalDialect<arith::ArithDialect>(
      [hasOnlyLegalTypes](Operation *op) -> std::optional<bool> {
        return hasOnlyLegalTypes(op);
      });

  // Vector ops that perform arithmetic have to be listed by hand.
  target.addDynamicallyLegalOp<
      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
      hasOnlyLegalTypes);

  // The conversion and constant ops themselves are always permitted; this is
  // registered last, after the dialect-wide rule above.
  target.addLegalOp<arith::ExtFOp, arith::TruncFOp, arith::ConstantOp,
                    vector::SplatOp>();
}
138+
139+
void EmulateUnsupportedFloatsPass::runOnOperation() {
140+
MLIRContext *ctx = &getContext();
141+
Operation *op = getOperation();
142+
SmallVector<Type> sourceTypes;
143+
Type targetType;
144+
145+
std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
146+
if (!maybeTargetType) {
147+
emitError(UnknownLoc::get(ctx), "could not map target type '" +
148+
targetTypeStr +
149+
"' to a known floating-point type");
150+
return signalPassFailure();
151+
}
152+
targetType = *maybeTargetType;
153+
for (StringRef sourceTypeStr : sourceTypeStrs) {
154+
std::optional<FloatType> maybeSourceType =
155+
parseFloatType(ctx, sourceTypeStr);
156+
if (!maybeSourceType) {
157+
emitError(UnknownLoc::get(ctx), "could not map source type '" +
158+
sourceTypeStr +
159+
"' to a known floating-point type");
160+
return signalPassFailure();
161+
}
162+
sourceTypes.push_back(*maybeSourceType);
163+
}
164+
if (sourceTypes.empty())
165+
(void)emitOptionalWarning(
166+
std::nullopt,
167+
"no source types specified, float emulation will do nothing");
168+
169+
if (llvm::is_contained(sourceTypes, targetType)) {
170+
emitError(UnknownLoc::get(ctx),
171+
"target type cannot be an unsupported source type");
172+
return signalPassFailure();
173+
}
174+
TypeConverter converter;
175+
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
176+
targetType);
177+
RewritePatternSet patterns(ctx);
178+
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
179+
ConversionTarget target(getContext());
180+
arith::populateEmulateUnsupportedFloatsLegality(target, converter);
181+
182+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
183+
signalPassFailure();
184+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
2+
3+
func.func @basic_expansion(%x: bf16) -> bf16 {
4+
// CHECK-LABEL: @basic_expansion
5+
// CHECK-SAME: [[X:%.+]]: bf16
6+
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
7+
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
8+
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
9+
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
10+
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
11+
// CHECK: return [[Y]]
12+
%c = arith.constant 1.0 : bf16
13+
%y = arith.addf %x, %c : bf16
14+
func.return %y : bf16
15+
}
16+
17+
// -----
18+
19+
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
20+
// CHECK-LABEL: @chained
21+
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
22+
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
23+
// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
24+
// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
25+
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
26+
// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
27+
// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
28+
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
29+
// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
30+
// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
31+
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
32+
// CHECK: return [[RES]]
33+
%p = arith.addf %x, %y : bf16
34+
%q = arith.mulf %p, %z : bf16
35+
%res = arith.cmpf ole, %p, %q : bf16
36+
func.return %res : i1
37+
}
38+
39+
// -----
40+
41+
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
42+
// CHECK-LABEL: @memops
43+
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
44+
// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
45+
// CHECK: memref.store [[V]]
46+
// CHECK: [[W:%.+]] = memref.load
47+
// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
48+
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
49+
// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
50+
// CHECK: memref.store [[X]]
51+
%c0 = arith.constant 0 : index
52+
%c1 = arith.constant 1 : index
53+
%v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
54+
memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
55+
%w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
56+
%x = arith.addf %v, %w : f8E4M3FNUZ
57+
memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
58+
func.return
59+
}
60+
61+
// -----
62+
63+
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
64+
// CHECK-LABEL: @vectors
65+
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
66+
// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
67+
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
68+
// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
69+
// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
70+
// CHECK: return [[RET]]
71+
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
72+
%ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
73+
func.return %ret : vector<4xf32>
74+
}
75+
76+
// -----
77+
78+
func.func @no_expansion(%x: f32) -> f32 {
79+
// CHECK-LABEL: @no_expansion
80+
// CHECK-SAME: [[X:%.+]]: f32
81+
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32
82+
// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32
83+
// CHECK: return [[Y]]
84+
%c = arith.constant 1.0 : f32
85+
%y = arith.addf %x, %c : f32
86+
func.return %y : f32
87+
}

0 commit comments

Comments
 (0)