Skip to content

[mlir][arith] Improve truncf folding #80206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 1, 2024
Merged

[mlir][arith] Improve truncf folding #80206

merged 3 commits into from
Feb 1, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Jan 31, 2024

  • Use APFloat conversion function instead of going through double to check if fold results in information loss.
  • Support folding vector constants.

* Use APFloat conversion function instead of going through double to check
if fold results in information loss.
* Support folding vector constants.
@kuhar kuhar marked this pull request as ready for review January 31, 2024 21:57
@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes
  • Use APFloat conversion function instead of going through double to check if fold results in information loss.
  • Support folding vector constants.

Full diff: https://github.com/llvm/llvm-project/pull/80206.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+15-15)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+9)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa7..a02f7d6dd5053 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -22,8 +22,10 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
@@ -1393,23 +1395,21 @@ LogicalResult arith::TruncIOp::verify() {
 // TruncFOp
 //===----------------------------------------------------------------------===//
 
-/// Perform safe const propagation for truncf, i.e. only propagate if FP value
+/// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss or rounding.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getIn();
-  if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
-    return {};
-
-  // Convert to target type via 'double'.
-  double sourceValue =
-      llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
-  auto targetAttr = FloatAttr::get(getType(), sourceValue);
-
-  // Propagate if constant's value does not change after truncation.
-  if (sourceValue == targetAttr.getValue().convertToDouble())
-    return targetAttr;
-
-  return {};
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+  return constFoldCastOp<FloatAttr, FloatAttr>(
+      adaptor.getOperands(), getType(),
+      [&targetSemantics](APFloat a, bool &castStatus) {
+        bool loosesInfo = false;
+        auto status =
+            a.convert(targetSemantics, llvm::RoundingMode::NearestTiesToEven,
+                      &loosesInfo);
+        castStatus = !loosesInfo && status == APFloat::opOK;
+        return a;
+      });
 }
 
 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 10050d87d7568..44df11ab2433a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -825,6 +825,15 @@ func.func @truncFPConstant() -> bf16 {
   return %0 : bf16
 }
 
+// CHECK-LABEL: @truncFPVectorConstant
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
+//       CHECK:   return %[[cres]]
+func.func @truncFPVectorConstant() -> vector<2xbf16> {
+  %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf32>
+  %0 = arith.truncf %cst : vector<2xf32> to vector<2xbf16>
+  return %0 : vector<2xbf16>
+}
+
 // Test that cases with rounding are NOT propagated
 // CHECK-LABEL: @truncFPConstantRounding
 //       CHECK:   arith.constant 1.444000e+25 : f32

@kuhar kuhar merged commit 730f498 into llvm:main Feb 1, 2024
carlosgalvezp pushed a commit to carlosgalvezp/llvm-project that referenced this pull request Feb 1, 2024
* Use APFloat conversion function instead of going through double to
check if fold results in information loss.
* Support folding vector constants.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants