@@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
275
275
.getResult (0 );
276
276
}
277
277
278
- // / Generates a call to obtain the coordindates array.
278
+ // / Generates a call to obtain the coordinates array.
279
279
static Value genCoordinatesCall (OpBuilder &builder, Location loc,
280
280
SparseTensorType stt, Value ptr, Level l) {
281
281
Type crdTp = stt.getCrdType ();
@@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
287
287
.getResult (0 );
288
288
}
289
289
290
+ // / Generates a call to obtain the coordinates array (AoS view).
291
+ static Value genCoordinatesBufferCall (OpBuilder &builder, Location loc,
292
+ SparseTensorType stt, Value ptr,
293
+ Level l) {
294
+ Type crdTp = stt.getCrdType ();
295
+ auto resTp = MemRefType::get ({ShapedType::kDynamic }, crdTp);
296
+ Value lvl = constantIndex (builder, loc, l);
297
+ SmallString<25 > name{" sparseCoordinatesBuffer" ,
298
+ overheadTypeFunctionSuffix (crdTp)};
299
+ return createFuncCall (builder, loc, name, resTp, {ptr, lvl},
300
+ EmitCInterface::On)
301
+ .getResult (0 );
302
+ }
303
+
290
304
// ===----------------------------------------------------------------------===//
291
305
// Conversion rules.
292
306
// ===----------------------------------------------------------------------===//
@@ -518,13 +532,35 @@ class SparseTensorToCoordinatesConverter
518
532
LogicalResult
519
533
matchAndRewrite (ToCoordinatesOp op, OpAdaptor adaptor,
520
534
ConversionPatternRewriter &rewriter) const override {
535
+ const Location loc = op.getLoc ();
536
+ auto stt = getSparseTensorType (op.getTensor ());
537
+ auto crds = genCoordinatesCall (rewriter, loc, stt, adaptor.getTensor (),
538
+ op.getLevel ());
539
+ // Cast the MemRef type to the type expected by the users, though these
540
+ // two types should be compatible at runtime.
541
+ if (op.getType () != crds.getType ())
542
+ crds = rewriter.create <memref::CastOp>(loc, op.getType (), crds);
543
+ rewriter.replaceOp (op, crds);
544
+ return success ();
545
+ }
546
+ };
547
+
548
+ // / Sparse conversion rule for coordinate accesses (AoS style).
549
+ class SparseToCoordinatesBufferConverter
550
+ : public OpConversionPattern<ToCoordinatesBufferOp> {
551
+ public:
552
+ using OpConversionPattern::OpConversionPattern;
553
+ LogicalResult
554
+ matchAndRewrite (ToCoordinatesBufferOp op, OpAdaptor adaptor,
555
+ ConversionPatternRewriter &rewriter) const override {
556
+ const Location loc = op.getLoc ();
521
557
auto stt = getSparseTensorType (op.getTensor ());
522
- auto crds = genCoordinatesCall (rewriter, op. getLoc (), stt,
523
- adaptor.getTensor (), op. getLevel ());
558
+ auto crds = genCoordinatesBufferCall (
559
+ rewriter, loc, stt, adaptor.getTensor (), stt. getAoSCOOStart ());
524
560
// Cast the MemRef type to the type expected by the users, though these
525
561
// two types should be compatible at runtime.
526
562
if (op.getType () != crds.getType ())
527
- crds = rewriter.create <memref::CastOp>(op. getLoc () , op.getType (), crds);
563
+ crds = rewriter.create <memref::CastOp>(loc , op.getType (), crds);
528
564
rewriter.replaceOp (op, crds);
529
565
return success ();
530
566
}
@@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
878
914
SparseTensorAllocConverter, SparseTensorEmptyConverter,
879
915
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
880
916
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
881
- SparseTensorToValuesConverter, SparseNumberOfEntriesConverter ,
882
- SparseTensorLoadConverter, SparseTensorInsertConverter ,
883
- SparseTensorExpandConverter, SparseTensorCompressConverter ,
884
- SparseTensorAssembleConverter, SparseTensorDisassembleConverter ,
885
- SparseHasRuntimeLibraryConverter>(typeConverter,
886
- patterns.getContext ());
917
+ SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter ,
918
+ SparseNumberOfEntriesConverter, SparseTensorLoadConverter ,
919
+ SparseTensorInsertConverter, SparseTensorExpandConverter ,
920
+ SparseTensorCompressConverter, SparseTensorAssembleConverter ,
921
+ SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
922
+ typeConverter, patterns.getContext ());
887
923
}
0 commit comments