diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index e18fa849e9f30..9ca93ab28daed 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2285,8 +2285,22 @@ class ArgMaxConverter : public OpRewritePattern { Value predicate; if (isa(inElementTy)) { - predicate = rewriter.create( - nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + if (argmaxOp.getNanMode() == "IGNORE") { + // Only update index & max value for non-NaN values. If all + // values are NaNs, the initial index, which is 0, is returned. + predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + } else { + // Update max value if either of the following is true: + // - new value is bigger + // - cur max is not NaN and new value is NaN + Value gt = rewriter.create( + nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue); + Value oldNonNaN = rewriter.create( + nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue); + predicate = rewriter.create( + nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); + } } else if (isa(inElementTy)) { predicate = rewriter.create( nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); @@ -2299,28 +2313,6 @@ class ArgMaxConverter : public OpRewritePattern { nestedLoc, predicate, newValue, oldValue); auto resultIndex = rewriter.create( nestedLoc, predicate, newIndex, oldIndex); - - // Check if we need to materialize compare and select for the given - // NaN propagation mode. - - // "PROPAGATE" matches the default NaN propagation mode of the arith - // dialect so no compare and select is required. - // - // In the case "IGNORE" we check if the current argument is NaN and - // select the old index and value otherwise take the updated index and - // value. 
- if (const auto nanMode = argmaxOp.getNanMode(); - isa(inElementTy) && nanMode == "IGNORE") { - // Unordered comparison of NaN against itself will always return - // true. - Value isNaN = rewriter.create( - argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue, - newValue); - resultMax = rewriter.create(nestedLoc, isNaN, - oldValue, resultMax); - resultIndex = rewriter.create( - nestedLoc, isNaN, oldIndex, resultIndex); - } nestedBuilder.create( nestedLoc, ValueRange({resultIndex, resultMax})); }); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 9258442de5a45..eafc62eb71e05 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1525,7 +1525,9 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () { // CHECK: arith.constant -3.40282347E+38 : f32 // CHECK: linalg.index // CHECK: arith.index_cast - // CHECK: arith.cmpf ogt + // CHECK: arith.cmpf ugt + // CHECK: arith.cmpf ord + // CHECK: andi // CHECK: select // CHECK: select // CHECK: linalg.yield @@ -2230,12 +2232,12 @@ func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> // CHECK-LABEL: @argmax_nan_propagate func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () { // CHECK: linalg.generic - // CHECK: arith.cmpf ogt + // CHECK: arith.cmpf ugt + // CHECK: arith.cmpf ord + // CHECK: andi // CHECK: arith.select // CHECK: arith.select // CHECK-NOT: arith.cmpf uno - // CHECK-NOT: arith.cmpf uno - // CHECK-NOT: arith.select // CHECK-NOT: arith.select // CHECK: linalg.yield %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32> @@ -2267,9 +2269,6 @@ func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> // CHECK: arith.cmpf ogt // CHECK: arith.select // CHECK: arith.select - // CHECK: arith.cmpf uno - // CHECK: 
arith.select - // CHECK: arith.select // CHECK: linalg.yield %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32> return