|
22 | 22 | #include "mlir/IR/PatternMatch.h"
|
23 | 23 | #include "mlir/IR/TypeUtilities.h"
|
24 | 24 |
|
| 25 | +#include "llvm/ADT/APFloat.h" |
25 | 26 | #include "llvm/ADT/APInt.h"
|
26 | 27 | #include "llvm/ADT/APSInt.h"
|
| 28 | +#include "llvm/ADT/FloatingPointMode.h" |
27 | 29 | #include "llvm/ADT/STLExtras.h"
|
28 | 30 | #include "llvm/ADT/SmallString.h"
|
29 | 31 | #include "llvm/ADT/SmallVector.h"
|
@@ -1393,23 +1395,21 @@ LogicalResult arith::TruncIOp::verify() {
|
1393 | 1395 | // TruncFOp
|
1394 | 1396 | //===----------------------------------------------------------------------===//
|
1395 | 1397 |
|
1396 |
| -/// Perform safe const propagation for truncf, i.e. only propagate if FP value |
1397 |
| -/// can be represented without precision loss or rounding. |
| 1398 | +/// Perform safe const propagation for truncf, i.e., only propagate if FP value |
| 1399 | +/// can be represented without precision loss or rounding. This is because the |
| 1400 | +/// semantics of `arith.truncf` do not assume a specific rounding mode. |
1398 | 1401 | OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
|
1399 |
| - auto constOperand = adaptor.getIn(); |
1400 |
| - if (!constOperand || !llvm::isa<FloatAttr>(constOperand)) |
1401 |
| - return {}; |
1402 |
| - |
1403 |
| - // Convert to target type via 'double'. |
1404 |
| - double sourceValue = |
1405 |
| - llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble(); |
1406 |
| - auto targetAttr = FloatAttr::get(getType(), sourceValue); |
1407 |
| - |
1408 |
| - // Propagate if constant's value does not change after truncation. |
1409 |
| - if (sourceValue == targetAttr.getValue().convertToDouble()) |
1410 |
| - return targetAttr; |
1411 |
| - |
1412 |
| - return {}; |
| 1402 | + auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType())); |
| 1403 | + const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); |
| 1404 | + return constFoldCastOp<FloatAttr, FloatAttr>( |
| 1405 | + adaptor.getOperands(), getType(), |
| 1406 | + [&targetSemantics](APFloat a, bool &castStatus) { |
| 1407 | + bool losesInfo = false; |
| 1408 | + auto status = a.convert( |
| 1409 | + targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo); |
| 1410 | + castStatus = !losesInfo && status == APFloat::opOK; |
| 1411 | + return a; |
| 1412 | + }); |
1413 | 1413 | }
|
1414 | 1414 |
|
1415 | 1415 | bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
|
0 commit comments