Skip to content

Commit 6071f6f

Browse files
committed
[mlir][sparse] Fix a problem in handling data type conversion.
Previously, the genCast function generated arith.trunci when converting f32 to i32. Fix the function to use mlir::convertScalarToDtype to correctly handle conversion cases beyond index casting. Add a test case for codegen of the sparse_tensor.convert op. Reviewed By: aartbik, Peiming, wrengr Differential Revision: https://reviews.llvm.org/D147272
1 parent b24e290 commit 6071f6f

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -208,28 +208,9 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
208208
if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
209209
return builder.create<arith::IndexCastOp>(loc, dstTp, value);
210210

211-
const bool ext =
212-
srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
213-
214-
// float => float.
215-
if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
216-
if (ext)
217-
return builder.create<arith::ExtFOp>(loc, dstTp, value);
218-
return builder.create<arith::TruncFOp>(loc, dstTp, value);
219-
}
220-
221-
// int => int
222-
const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
223-
if (srcIntTp && dstTp.isa<IntegerType>()) {
224-
if (!ext)
225-
return builder.create<arith::TruncIOp>(loc, dstTp, value);
226-
if (srcIntTp.isUnsigned())
227-
return builder.create<arith::ExtUIOp>(loc, dstTp, value);
228-
if (srcIntTp.isSigned())
229-
return builder.create<arith::ExtSIOp>(loc, dstTp, value);
230-
}
231-
232-
llvm_unreachable("unhandled type casting");
211+
const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
212+
const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
213+
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
233214
}
234215

235216
mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,22 @@ func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?x
663663
return %0 : tensor<?xf32, #SparseVector>
664664
}
665665

666+
// CHECK-LABEL: func.func @sparse_convert_element_type(
667+
// CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
668+
// CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
669+
// CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
670+
// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
671+
// CHECK: scf.for
672+
// CHECK: %[[FValue:.*]] = memref.load
673+
// CHECK: %[[IValue:.*]] = arith.fptosi %[[FValue]]
674+
// CHECK: memref.store %[[IValue]]
675+
// CHECK: return %{{.*}}, %{{.*}}, %{{.*}}, %[[A4]] :
676+
// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xi32>, !sparse_tensor.storage_specifier
677+
func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
678+
%0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xi32, #SparseVector>
679+
return %0 : tensor<?xi32, #SparseVector>
680+
}
681+
666682
// CHECK-LABEL: func.func @sparse_new_coo(
667683
// CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "singleton" ] }>>) {
668684
// CHECK-DAG: %[[A1:.*]] = arith.constant false

0 commit comments

Comments
 (0)