Skip to content

Commit 730f498

Browse files
authored
[mlir][arith] Improve truncf folding (llvm#80206)
* Use APFloat conversion function instead of going through double to check if fold results in information loss. * Support folding vector constants.
1 parent 8ba018d commit 730f498

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/IR/TypeUtilities.h"
2424

25+
#include "llvm/ADT/APFloat.h"
2526
#include "llvm/ADT/APInt.h"
2627
#include "llvm/ADT/APSInt.h"
28+
#include "llvm/ADT/FloatingPointMode.h"
2729
#include "llvm/ADT/STLExtras.h"
2830
#include "llvm/ADT/SmallString.h"
2931
#include "llvm/ADT/SmallVector.h"
@@ -1393,23 +1395,21 @@ LogicalResult arith::TruncIOp::verify() {
13931395
// TruncFOp
13941396
//===----------------------------------------------------------------------===//
13951397

1396-
/// Perform safe const propagation for truncf, i.e. only propagate if FP value
1397-
/// can be represented without precision loss or rounding.
1398+
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1399+
/// can be represented without precision loss or rounding. This is because the
1400+
/// semantics of `arith.truncf` do not assume a specific rounding mode.
13981401
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1399-
auto constOperand = adaptor.getIn();
1400-
if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
1401-
return {};
1402-
1403-
// Convert to target type via 'double'.
1404-
double sourceValue =
1405-
llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
1406-
auto targetAttr = FloatAttr::get(getType(), sourceValue);
1407-
1408-
// Propagate if constant's value does not change after truncation.
1409-
if (sourceValue == targetAttr.getValue().convertToDouble())
1410-
return targetAttr;
1411-
1412-
return {};
1402+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1403+
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1404+
return constFoldCastOp<FloatAttr, FloatAttr>(
1405+
adaptor.getOperands(), getType(),
1406+
[&targetSemantics](APFloat a, bool &castStatus) {
1407+
bool losesInfo = false;
1408+
auto status = a.convert(
1409+
targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
1410+
castStatus = !losesInfo && status == APFloat::opOK;
1411+
return a;
1412+
});
14131413
}
14141414

14151415
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,15 @@ func.func @truncFPConstant() -> bf16 {
825825
return %0 : bf16
826826
}
827827

828+
// CHECK-LABEL: @truncFPVectorConstant
829+
// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
830+
// CHECK: return %[[cres]]
831+
func.func @truncFPVectorConstant() -> vector<2xbf16> {
832+
%cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf32>
833+
%0 = arith.truncf %cst : vector<2xf32> to vector<2xbf16>
834+
return %0 : vector<2xbf16>
835+
}
836+
828837
// Test that cases with rounding are NOT propagated
829838
// CHECK-LABEL: @truncFPConstantRounding
830839
// CHECK: arith.constant 1.444000e+25 : f32

0 commit comments

Comments
 (0)