Skip to content

Commit dc4cfdb

Browse files
authored
[mlir][sparse] provide an AoS "view" into sparse runtime support lib (#87116)
Note that even though the sparse runtime support lib always uses SoA storage for COO storage (and provides correct codegen by means of views into this storage), in some rare cases we need the true physical SoA storage as a coordinate buffer. This PR provides that functionality by means of a (costly) coordinate buffer call. Since this is currently only used for testing/debugging by means of the sparse_tensor.print method, this solution is acceptable. If we ever want a performing version of this, we should truly support AoS storage of COO in addition to the SoA used right now.
1 parent 038e66f commit dc4cfdb

File tree

7 files changed

+152
-23
lines changed

7 files changed

+152
-23
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

+35
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ class SparseTensorStorageBase {
143143
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
144144
#undef DECL_GETCOORDINATES
145145

146+
/// Gets coordinates-overhead storage buffer for the given level.
147+
#define DECL_GETCOORDINATESBUFFER(INAME, C) \
148+
virtual void getCoordinatesBuffer(std::vector<C> **, uint64_t);
149+
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATESBUFFER)
150+
#undef DECL_GETCOORDINATESBUFFER
151+
146152
/// Gets primary storage.
147153
#define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
148154
MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
@@ -251,6 +257,31 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
251257
assert(lvl < getLvlRank());
252258
*out = &coordinates[lvl];
253259
}
260+
void getCoordinatesBuffer(std::vector<C> **out, uint64_t lvl) final {
261+
assert(out && "Received nullptr for out parameter");
262+
assert(lvl < getLvlRank());
263+
// Note that the sparse tensor support library always stores COO in SoA
264+
// format, even when AoS is requested. This is never an issue, since all
265+
// actual code/library generation requests "views" into the coordinate
266+
// storage for the individual levels, which is trivially provided for
267+
// both AoS and SoA (as well as all the other storage formats). The only
268+
// exception is when the buffer version of coordinate storage is requested
269+
// (currently only for printing). In that case, we do the following
270+
// potentially expensive transformation to provide that view. If this
271+
// operation becomes more common beyond debugging, we should consider
272+
// implementing proper AoS in the support library as well.
273+
uint64_t lvlRank = getLvlRank();
274+
uint64_t nnz = values.size();
275+
crdBuffer.clear();
276+
crdBuffer.reserve(nnz * (lvlRank - lvl));
277+
for (uint64_t i = 0; i < nnz; i++) {
278+
for (uint64_t l = lvl; l < lvlRank; l++) {
279+
assert(i < coordinates[l].size());
280+
crdBuffer.push_back(coordinates[l][i]);
281+
}
282+
}
283+
*out = &crdBuffer;
284+
}
254285
void getValues(std::vector<V> **out) final {
255286
assert(out && "Received nullptr for out parameter");
256287
*out = &values;
@@ -529,10 +560,14 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
529560
return -1u;
530561
}
531562

563+
// Sparse tensor storage components.
532564
std::vector<std::vector<P>> positions;
533565
std::vector<std::vector<C>> coordinates;
534566
std::vector<V> values;
567+
568+
// Auxiliary data structures.
535569
std::vector<uint64_t> lvlCursor;
570+
std::vector<C> crdBuffer; // just for AoS view
536571
};
537572

538573
//===----------------------------------------------------------------------===//

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

+8
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
7777
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
7878
#undef DECL_SPARSECOORDINATES
7979

80+
/// Tensor-storage method to obtain direct access to the coordinates array
81+
/// buffer for the given level (provides an AoS view into the library).
82+
#define DECL_SPARSECOORDINATES(CNAME, C) \
83+
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_sparseCoordinatesBuffer##CNAME( \
84+
StridedMemRefType<C, 1> *out, void *tensor, index_type lvl);
85+
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
86+
#undef DECL_SPARSECOORDINATES
87+
8088
/// Tensor-storage method to insert elements in lexicographical
8189
/// level-coordinate order.
8290
#define DECL_LEXINSERT(VNAME, V) \

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

+46-10
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
275275
.getResult(0);
276276
}
277277

278-
/// Generates a call to obtain the coordindates array.
278+
/// Generates a call to obtain the coordinates array.
279279
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
280280
SparseTensorType stt, Value ptr, Level l) {
281281
Type crdTp = stt.getCrdType();
@@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
287287
.getResult(0);
288288
}
289289

290+
/// Generates a call to obtain the coordinates array (AoS view).
291+
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
292+
SparseTensorType stt, Value ptr,
293+
Level l) {
294+
Type crdTp = stt.getCrdType();
295+
auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
296+
Value lvl = constantIndex(builder, loc, l);
297+
SmallString<25> name{"sparseCoordinatesBuffer",
298+
overheadTypeFunctionSuffix(crdTp)};
299+
return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
300+
EmitCInterface::On)
301+
.getResult(0);
302+
}
303+
290304
//===----------------------------------------------------------------------===//
291305
// Conversion rules.
292306
//===----------------------------------------------------------------------===//
@@ -518,13 +532,35 @@ class SparseTensorToCoordinatesConverter
518532
LogicalResult
519533
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
520534
ConversionPatternRewriter &rewriter) const override {
535+
const Location loc = op.getLoc();
536+
auto stt = getSparseTensorType(op.getTensor());
537+
auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
538+
op.getLevel());
539+
// Cast the MemRef type to the type expected by the users, though these
540+
// two types should be compatible at runtime.
541+
if (op.getType() != crds.getType())
542+
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
543+
rewriter.replaceOp(op, crds);
544+
return success();
545+
}
546+
};
547+
548+
/// Sparse conversion rule for coordinate accesses (AoS style).
549+
class SparseToCoordinatesBufferConverter
550+
: public OpConversionPattern<ToCoordinatesBufferOp> {
551+
public:
552+
using OpConversionPattern::OpConversionPattern;
553+
LogicalResult
554+
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
555+
ConversionPatternRewriter &rewriter) const override {
556+
const Location loc = op.getLoc();
521557
auto stt = getSparseTensorType(op.getTensor());
522-
auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
523-
adaptor.getTensor(), op.getLevel());
558+
auto crds = genCoordinatesBufferCall(
559+
rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
524560
// Cast the MemRef type to the type expected by the users, though these
525561
// two types should be compatible at runtime.
526562
if (op.getType() != crds.getType())
527-
crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
563+
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
528564
rewriter.replaceOp(op, crds);
529565
return success();
530566
}
@@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
878914
SparseTensorAllocConverter, SparseTensorEmptyConverter,
879915
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
880916
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
881-
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
882-
SparseTensorLoadConverter, SparseTensorInsertConverter,
883-
SparseTensorExpandConverter, SparseTensorCompressConverter,
884-
SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
885-
SparseHasRuntimeLibraryConverter>(typeConverter,
886-
patterns.getContext());
917+
SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
918+
SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
919+
SparseTensorInsertConverter, SparseTensorExpandConverter,
920+
SparseTensorCompressConverter, SparseTensorAssembleConverter,
921+
SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
922+
typeConverter, patterns.getContext());
887923
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,9 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
648648
loc, lvl, vector::PrintPunctuation::NoPunctuation);
649649
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
650650
Value crd = nullptr;
651-
// TODO: eliminates ToCoordinateBufferOp!
651+
// For COO AoS storage, we want to print a single, linear view of
652+
// the full coordinate storage at this level. For any other storage,
653+
// we show the coordinate storage for every indivual level.
652654
if (stt.getAoSCOOStart() == l)
653655
crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
654656
else

mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
6868
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
6969
#undef IMPL_GETCOORDINATES
7070

71+
#define IMPL_GETCOORDINATESBUFFER(CNAME, C) \
72+
void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **, \
73+
uint64_t) { \
74+
FATAL_PIV("getCoordinatesBuffer" #CNAME); \
75+
}
76+
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
77+
#undef IMPL_GETCOORDINATESBUFFER
78+
7179
#define IMPL_GETVALUES(VNAME, V) \
7280
void SparseTensorStorageBase::getValues(std::vector<V> **) { \
7381
FATAL_PIV("getValues" #VNAME); \

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
311311
assert(v); \
312312
aliasIntoMemref(v->size(), v->data(), *ref); \
313313
}
314+
314315
#define IMPL_SPARSEPOSITIONS(PNAME, P) \
315316
IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
316317
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
@@ -320,6 +321,12 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
320321
IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
321322
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
322323
#undef IMPL_SPARSECOORDINATES
324+
325+
#define IMPL_SPARSECOORDINATESBUFFER(CNAME, C) \
326+
IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
327+
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
328+
#undef IMPL_SPARSECOORDINATESBUFFER
329+
323330
#undef IMPL_GETOVERHEAD
324331

325332
#define IMPL_LEXINSERT(VNAME, V) \

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir

+45-12
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@
120120
)
121121
}>
122122

123+
#COOAoS = #sparse_tensor.encoding<{
124+
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
125+
}>
126+
127+
#COOSoA = #sparse_tensor.encoding<{
128+
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
129+
}>
130+
123131
module {
124132

125133
//
@@ -161,6 +169,8 @@ module {
161169
%h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
162170
%i = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR0>
163171
%j = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC0>
172+
%AoS = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOAoS>
173+
%SoA = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOSoA>
164174

165175
// CHECK-NEXT: ---- Sparse Tensor ----
166176
// CHECK-NEXT: nse = 5
@@ -274,19 +284,42 @@ module {
274284
// CHECK-NEXT: ----
275285
sparse_tensor.print %j : tensor<4x8xi32, #BSC0>
276286

287+
// CHECK-NEXT: ---- Sparse Tensor ----
288+
// CHECK-NEXT: nse = 5
289+
// CHECK-NEXT: dim = ( 4, 8 )
290+
// CHECK-NEXT: lvl = ( 4, 8 )
291+
// CHECK-NEXT: pos[0] : ( 0, 5,
292+
// CHECK-NEXT: crd[0] : ( 0, 0, 0, 2, 3, 2, 3, 3, 3, 5,
293+
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
294+
// CHECK-NEXT: ----
295+
sparse_tensor.print %AoS : tensor<4x8xi32, #COOAoS>
296+
297+
// CHECK-NEXT: ---- Sparse Tensor ----
298+
// CHECK-NEXT: nse = 5
299+
// CHECK-NEXT: dim = ( 4, 8 )
300+
// CHECK-NEXT: lvl = ( 4, 8 )
301+
// CHECK-NEXT: pos[0] : ( 0, 5,
302+
// CHECK-NEXT: crd[0] : ( 0, 0, 3, 3, 3,
303+
// CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
304+
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
305+
// CHECK-NEXT: ----
306+
sparse_tensor.print %SoA : tensor<4x8xi32, #COOSoA>
307+
277308
// Release the resources.
278-
bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
279-
bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
280-
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
281-
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
282-
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
283-
bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
284-
bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
285-
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
286-
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
287-
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
288-
bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
289-
bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
309+
bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
310+
bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
311+
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
312+
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
313+
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
314+
bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
315+
bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
316+
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
317+
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
318+
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
319+
bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
320+
bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
321+
bufferization.dealloc_tensor %AoS : tensor<4x8xi32, #COOAoS>
322+
bufferization.dealloc_tensor %SoA : tensor<4x8xi32, #COOSoA>
290323

291324
return
292325
}

0 commit comments

Comments
 (0)