Skip to content

Commit 097d2f1

Browse files
authored
[mlir][sparse] optimize memory load to SSA value when generating sparse conv kernel. (#74750)
1 parent 3ed940a commit 097d2f1

File tree

3 files changed

+208
-247
lines changed

3 files changed

+208
-247
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,6 @@ static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
167167
Value pPtr) {
168168
builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
169169
}
170-
static Value loadSlicePosTupleNum(OpBuilder &builder, Location loc,
171-
Value sPosBuf) {
172-
return genIndexLoad(builder, loc, sPosBuf, C_IDX(0));
173-
}
174-
static void updateSlicePosTupleNum(OpBuilder &builder, Location loc, Value num,
175-
Value sPosBuf) {
176-
builder.create<memref::StoreOp>(loc, num, sPosBuf, C_IDX(0));
177-
}
178170

179171
// Gets and sets position values for slice-driven loops.
180172
enum class SlicePosKind { kLo, kHi, kNext };
@@ -405,7 +397,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
405397
sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
406398
sliceStack[tid].emplace_back(/*minCrd=*/Value(),
407399
/*offset=*/Value(), /*isNonEmpty*/ Value(),
408-
std::nullopt, 0);
400+
/*posTupleNum=*/Value(), std::nullopt, 0);
409401
if (dimGetter && !isSynTensor(tid)) {
410402
for (Level l = 0; l < lvlRank; l++) {
411403
dependentLvlMap[tid][l] = dimGetter(tid, l);
@@ -1797,7 +1789,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
17971789
unsigned depth = frontSlice.depth - 1;
17981790
Value offset = frontSlice.offset;
17991791
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
1800-
Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf);
1792+
Value mSz = frontSlice.posTupleNum;
18011793
outerMost = builder.create<scf::ForOp>(
18021794
loc, c0, mSz, c1, innerArgs,
18031795
[this, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
@@ -1908,7 +1900,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
19081900
// Dense slice begin is trivial.
19091901
sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0,
19101902
/*nonEmpty=*/constantI1(builder, loc, true),
1911-
lvl, /*depth=*/1);
1903+
c0, lvl, /*depth=*/1);
19121904
return;
19131905
}
19141906
auto [nxSz, stride] = sliceMeta[tid][lvl][1];
@@ -1924,12 +1916,13 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
19241916
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
19251917
ADDI(posits[tid][lvl - 1], c1));
19261918
}
1927-
// Fills out pIdxBuffer[tid][lvl][0] with [/*memSize =*/4, 0, pLo, pHi]
1928-
updateSlicePosTupleNum(builder, loc, c1, sPtrBuf);
1919+
// Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
19291920
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
19301921
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
19311922
updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
1932-
1923+
// Slice over a resolved parent, we only need one pair of pos hi and lo to
1924+
// specify the current slice.
1925+
Value tupleNum = c1;
19331926
// This is a non-empty tensor if pLo < pHi.
19341927
Value isNonEmpty = CMPI(ult, pLo, pHi);
19351928
// The minimal coord must be at the first on ordered level.
@@ -1941,7 +1934,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
19411934

19421935
// FIXME: We need the relative offset related to the base slice.
19431936
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
1944-
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl,
1937+
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, tupleNum, lvl,
19451938
/*depth=*/1);
19461939
}
19471940

@@ -1973,8 +1966,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19731966
Value remSz = sliceMeta[tid][lvl][depth + 1].first;
19741967
// Dense slice begin is trivial
19751968
if (isDenseLT(lvlTypes[tid][lvl])) {
1976-
sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
1977-
depth + 1);
1969+
sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), c0,
1970+
lvl, depth + 1);
19781971
return;
19791972
}
19801973

@@ -2064,11 +2057,11 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
20642057
Value minCrd = result[1];
20652058
// Two metadata [memSize, idx].
20662059
// TODO: Can use an SSA value for these two metadata
2067-
updateSlicePosTupleNum(builder, loc, result[2], sPtrBuf);
20682060
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
20692061
// FIXME: we need the relative offset related to the base slice.
20702062
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
2071-
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
2063+
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
2064+
depth + 1);
20722065
}
20732066

20742067
bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
@@ -2212,10 +2205,10 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
22122205
// offset = minCrd - size + 1;
22132206
// }
22142207
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
2215-
reduc[2] = absOffset; // restore value.
2216-
Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf); // memSize
2217-
reduc[0] = lvlSizes[tid][lvl]; // next min coord
2218-
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
2208+
reduc[2] = absOffset; // restore value.
2209+
Value mSz = info.posTupleNum; // tuple number.
2210+
reduc[0] = lvlSizes[tid][lvl]; // next min coord
2211+
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
22192212
auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
22202213
auto forOp = scf::buildLoopNest(
22212214
builder, loc, c0, mSz, c1, loopArgs,

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -298,20 +298,21 @@ class LoopEmitter {
298298
struct SliceInfo final {
299299
// Note that we do not need to create an actual sparse tensor slice but
300300
// instead only need to maintain the metadata of the slice.
301-
SliceInfo(Value minCrd, Value offset, Value isNonEmpty,
301+
SliceInfo(Value minCrd, Value offset, Value isNonEmpty, Value posTupleNum,
302302
std::optional<Level> slicedOnLvl, unsigned depth)
303303
: minCrd(minCrd), offset(offset), isNonEmpty(isNonEmpty),
304-
slicedOnLvl(slicedOnLvl), depth(depth) {
304+
posTupleNum(posTupleNum), slicedOnLvl(slicedOnLvl), depth(depth) {
305305
// TODO: use std::optional<pair<Level, minCrd>>
306306
assert(!slicedOnLvl || minCrd);
307307
}
308308

309309
// Whether this is the tensor that has not yet been sliced.
310310
bool isInitialTensor() const { return !slicedOnLvl.has_value(); }
311311

312-
Value minCrd; // the minimum coordinate of the slice.
313-
Value offset; // the *absolute* offset of the current slice.
314-
Value isNonEmpty; // whether the slice is empty.
312+
Value minCrd; // the minimum coordinate of the slice.
313+
Value offset; // the *absolute* offset of the current slice.
314+
Value isNonEmpty; // whether the slice is non-empty.
315+
Value posTupleNum; // The number of position tuples used in the slice.
315316
std::optional<Level> slicedOnLvl; // the level on which the slice is done
316317
unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
317318
};
@@ -650,17 +651,6 @@ class LoopEmitter {
650651
std::vector<std::vector<LevelType>> lvlTypes;
651652
// Sparse iteration information for each `(TensorId, Level)` pair.
652653
// These arrays are updated to remain current within the current loop.
653-
// TODO: Clarify which of these are indexed by dstLvl vs srcLvl.
654-
//
655-
/// The collection of positions for a given element (one such collection
656-
/// for each tensor). This is the position analogue of the "coords"
657-
/// naming convention.
658-
///
659-
/// FIXME: [CLARIFY_POSITS_LVL] It's unclear which levels are used
660-
/// to index the `posits` array. On the one hand `genSparseCrd`
661-
/// uses dstLvl; on the other hand `enterLoopOverTensorAtLvl`,
662-
/// `prepareLoopOverTensorAtLvl`, and `enterCoIterationOverTensorsAtLvls`
663-
/// uses srcLvl. So which is it?
664654
std::vector<std::vector<Value>> posits;
665655
/// The collection of coordinates for a given element (one such
666656
/// collection for each tensor).
@@ -704,10 +694,6 @@ class LoopEmitter {
704694
// sliceStack[tid] holds the generated slice stack on tid.
705695
std::vector<std::vector<SliceInfo>> sliceStack;
706696

707-
/// TODO: not yet used, it should track the current level for each tensor
708-
/// to help eliminate `lvls` paramters from above APIs.
709-
/// std::vector<Level> curLvl;
710-
711697
//
712698
// Fields which have at most `numLoops` many entries.
713699
//

0 commit comments

Comments
 (0)