Skip to content

Commit f8f0a1e

Browse files
Ettore Tiotto and tungld authored
Use builder based interface to generate Krnl loops (llvm#1250)
* [ClipOp]: Use builder based interface to generate Krnl loops Signed-off-by: Ettore Tiotto <[email protected]> * [TileOp]: Use builder based interface to generate Krnl loops Signed-off-by: Ettore Tiotto <[email protected]> * [Transpose]: Use builder based interface to generate Krnl loops Signed-off-by: Ettore Tiotto <[email protected]> * [Transpose]: Remove #ifdef out code Signed-off-by: Ettore Tiotto <[email protected]> * [Split]: Use builder based interface to generate Krnl loops Signed-off-by: Ettore Tiotto <[email protected]> * Address code review comments Signed-off-by: Ettore Tiotto <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent bbda7b2 commit f8f0a1e

File tree

10 files changed

+380
-299
lines changed

10 files changed

+380
-299
lines changed

src/Conversion/ONNXToKrnl/Math/Clip.cpp

+54-60
Original file line numberDiff line numberDiff line change
@@ -28,73 +28,67 @@ struct ONNXClipOpLowering : public ConversionPattern {
2828
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2929
ConversionPatternRewriter &rewriter) const final {
3030
Location loc = op->getLoc();
31-
Value input = operands[0];
32-
Value min = operands[1];
33-
Value max = operands[2];
31+
ONNXClipOp clipOp = cast<ONNXClipOp>(op);
32+
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
3433

35-
// Insert an allocation and deallocation for the result of this operation.
36-
auto memRefType = convertToMemRefType(*op->result_type_begin());
37-
38-
Value alloc;
39-
bool insertDealloc = checkInsertDealloc(op);
34+
ONNXClipOpAdaptor operandAdaptor(operands);
35+
ONNXClipOpShapeHelper shapeHelper(&clipOp, &rewriter,
36+
getDenseElementAttributeFromKrnlValue,
37+
loadDenseElementArrayValueAtIndex);
38+
auto shapeComputed = shapeHelper.computeShape(operandAdaptor);
39+
assert(succeeded(shapeComputed));
4040

41-
if (hasAllConstantDimensions(memRefType))
42-
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
43-
else
44-
alloc = insertAllocAndDealloc(
45-
memRefType, loc, rewriter, insertDealloc, input);
41+
Value input = operandAdaptor.input();
42+
Value min = operandAdaptor.min();
43+
Value max = operandAdaptor.max();
4644

47-
SmallVector<Value, 4> loopIVs;
48-
// Only create krnl.iterate if one of the operands is not scalar tensor.
45+
// Insert an allocation and deallocation for the result of this operation.
46+
bool insertDealloc = checkInsertDealloc(op);
47+
Value alloc =
48+
(hasAllConstantDimensions(memRefType))
49+
? insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc)
50+
: insertAllocAndDealloc(
51+
memRefType, loc, rewriter, insertDealloc, input);
52+
53+
auto computeResult =
54+
[&](MultiDialectBuilder<KrnlBuilder, MathBuilder> &create,
55+
const ValueRange &indices) {
56+
Value loadedVal = create.krnl.load(input, indices);
57+
Value res = loadedVal;
58+
if (!min.getType().isa<NoneType>()) {
59+
Value minVal = create.krnl.load(min);
60+
Value lessThanMin = create.math.slt(res, minVal);
61+
res = create.math.select(lessThanMin, minVal, res);
62+
}
63+
if (!max.getType().isa<NoneType>()) {
64+
Value maxVal = create.krnl.load(max);
65+
Value lessThanMax = create.math.slt(res, maxVal);
66+
res = create.math.select(lessThanMax, res, maxVal);
67+
}
68+
create.krnl.store(res, alloc, indices);
69+
};
70+
71+
// Create a loop only if one of the operands is not a scalar tensor.
4972
if (!hasAllScalarValues(operands)) {
50-
// Create iterateOp & get block within iterate op.
51-
BuildKrnlLoop loops(rewriter, loc, memRefType.getRank());
52-
loops.createDefineAndIterateOp(input);
53-
Block *iterationBlock = loops.getIterateBlock();
54-
55-
// Insert instructions inside the KernelIterateOp body.
56-
rewriter.setInsertionPointToStart(iterationBlock);
57-
58-
// Handle the operation:
59-
for (auto arg : iterationBlock->getArguments())
60-
loopIVs.push_back(arg);
61-
}
62-
63-
// Load unary first operand.
64-
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
65-
Value loadedVal = create.krnl.load(input, loopIVs);
66-
Type inputType = loadedVal.getType();
67-
Value res = loadedVal;
68-
69-
if (inputType.isa<FloatType>()) {
70-
if (!min.getType().isa<NoneType>()) {
71-
Value minVal = create.krnl.load(min);
72-
Value lessThanMin = create.math.slt(res, minVal);
73-
res = create.math.select(lessThanMin, minVal, res);
74-
}
75-
if (!max.getType().isa<NoneType>()) {
76-
Value maxVal = create.krnl.load(max);
77-
Value lessThanMax = create.math.slt(res, maxVal);
78-
res = create.math.select(lessThanMax, res, maxVal);
79-
}
80-
} else if (inputType.isa<IntegerType>()) {
81-
if (!min.getType().isa<NoneType>()) {
82-
Value minVal = create.krnl.load(min);
83-
Value lessThanMin = create.math.slt(res, minVal);
84-
res = create.math.select(lessThanMin, minVal, res);
85-
}
86-
if (!max.getType().isa<NoneType>()) {
87-
Value maxVal = create.krnl.load(max);
88-
Value lessThanMax = create.math.slt(res, maxVal);
89-
res = create.math.select(lessThanMax, res, maxVal);
90-
}
73+
KrnlBuilder createKrnl(rewriter, loc);
74+
uint64_t numLoops = memRefType.getRank();
75+
ValueRange loopDef = createKrnl.defineLoops(numLoops);
76+
77+
SmallVector<IndexExpr, 4> lbs(numLoops, LiteralIndexExpr(0));
78+
SmallVector<IndexExpr, 4> ubs;
79+
for (uint64_t i = 0; i < numLoops; ++i)
80+
ubs.emplace_back(shapeHelper.dimsForOutput()[i]);
81+
82+
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
83+
[&](KrnlBuilder &createKrnl, ValueRange indices) {
84+
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
85+
computeResult(create, indices);
86+
});
9187
} else {
92-
llvm_unreachable("unsupported element type");
88+
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
89+
computeResult(create, {});
9390
}
9491

95-
// Store result in the resulting array.
96-
create.krnl.store(res, alloc, loopIVs);
97-
9892
rewriter.replaceOp(op, alloc);
9993
return success();
10094
}

src/Conversion/ONNXToKrnl/Tensor/Split.cpp

+37-32
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ template <typename Adaptor, typename Op, typename ShapeHelper>
2121
LogicalResult ONNXSplitOpLoweringCommon(Operation *op, ArrayRef<Value> operands,
2222
ConversionPatternRewriter &rewriter) {
2323
// Gather info.
24-
auto loc = op->getLoc();
24+
Location loc = op->getLoc();
2525
Adaptor operandAdaptor(operands);
26-
Op splitOp = llvm::dyn_cast<Op>(op);
27-
auto rank = splitOp.input().getType().template cast<ShapedType>().getRank();
28-
auto outputNum = splitOp.getNumResults();
29-
auto axis = splitOp.axis();
26+
Op splitOp = cast<Op>(op);
27+
uint64_t rank =
28+
splitOp.input().getType().template cast<ShapedType>().getRank();
29+
unsigned outputNum = splitOp.getNumResults();
30+
unsigned axis = splitOp.axis();
3031

3132
// Get a shape helper.
3233
ShapeHelper shapeHelper(&splitOp, &rewriter,
@@ -36,7 +37,7 @@ LogicalResult ONNXSplitOpLoweringCommon(Operation *op, ArrayRef<Value> operands,
3637

3738
// Alloc and dealloc.
3839
SmallVector<Value, 4> allocs;
39-
for (unsigned int i = 0; i < outputNum; ++i) {
40+
for (unsigned i = 0; i < outputNum; ++i) {
4041
checkInsertDealloc(op, i);
4142
auto memRefType = convertToMemRefType(splitOp.outputs()[i].getType());
4243
Value alloc = insertAllocAndDeallocSimple(
@@ -45,40 +46,44 @@ LogicalResult ONNXSplitOpLoweringCommon(Operation *op, ArrayRef<Value> operands,
4546
}
4647

4748
// Creates loops, one for each output.
48-
for (unsigned int i = 0; i < outputNum; ++i) {
49+
for (unsigned i = 0; i < outputNum; ++i) {
4950
OpBuilder::InsertionGuard insertGuard(rewriter);
50-
// Create loop.
51-
BuildKrnlLoop outputLoops(rewriter, loc, rank);
52-
outputLoops.createDefineAndIterateOp(allocs[i]);
53-
rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());
5451

5552
// Scope for krnl ops
5653
IndexExprScope childScope(&rewriter, shapeHelper.scope);
54+
5755
KrnlBuilder createKrnl(rewriter, loc);
56+
ValueRange loopDef = createKrnl.defineLoops(rank);
57+
SmallVector<IndexExpr, 4> lbs(rank, LiteralIndexExpr(0));
58+
59+
MemRefBoundsIndexCapture allocsBounds(allocs[i]);
60+
SmallVector<IndexExpr, 4> ubs;
61+
allocsBounds.getDimList(ubs);
62+
63+
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
64+
[&](KrnlBuilder &createKrnl, ValueRange indices) {
65+
SmallVector<IndexExpr, 4> readIndices;
66+
for (uint64_t r = 0; r < rank; ++r) {
67+
DimIndexExpr readIndex(indices[r]);
68+
// Compute read index for the split axis.
69+
if (r == axis)
70+
for (unsigned k = 0; k < i; ++k) {
71+
SymbolIndexExpr splitDim(shapeHelper.dimsForOutput(k)[r]);
72+
readIndex = readIndex + splitDim;
73+
}
5874

59-
// Indices for the read and write.
60-
SmallVector<IndexExpr, 4> readIndices;
61-
SmallVector<IndexExpr, 4> writeIndices;
62-
for (int r = 0; r < rank; ++r) {
63-
Value readVal = outputLoops.getInductionVar(r);
64-
// If not the split axis, same index for read and write
65-
IndexExpr readIndex = DimIndexExpr(readVal);
66-
DimIndexExpr writeIndex(readVal);
67-
// If the split axis, compute read index for the split axis.
68-
if (r == axis) {
69-
for (unsigned int k = 0; k < i; ++k) {
70-
IndexExpr splitDim = SymbolIndexExpr(shapeHelper.dimsForOutput(k)[r]);
71-
readIndex = readIndex + splitDim;
72-
}
73-
}
74-
readIndices.emplace_back(readIndex);
75-
writeIndices.emplace_back(writeIndex);
76-
}
77-
// Insert copy.
78-
Value loadData = createKrnl.loadIE(operandAdaptor.input(), readIndices);
79-
createKrnl.storeIE(loadData, allocs[i], writeIndices);
75+
readIndices.emplace_back(readIndex);
76+
}
77+
78+
// Insert copy.
79+
Value loadData =
80+
createKrnl.loadIE(operandAdaptor.input(), readIndices);
81+
createKrnl.store(loadData, allocs[i], indices);
82+
});
8083
}
84+
8185
rewriter.replaceOp(op, allocs);
86+
8287
return success();
8388
}
8489

src/Conversion/ONNXToKrnl/Tensor/Tile.cpp

+26-39
Original file line numberDiff line numberDiff line change
@@ -71,48 +71,35 @@ struct ONNXTileOpLowering : public ConversionPattern {
7171
(void)shapecomputed;
7272
assert(!failed(shapecomputed) && "expected to succeed");
7373

74-
MemRefType outputMemRefType = convertToMemRefType(*op->result_type_begin());
75-
auto outputMemRefShape = outputMemRefType.getShape();
76-
int64_t outputRank = outputMemRefShape.size();
74+
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
75+
llvm::ArrayRef<int64_t> memRefShape = memRefType.getShape();
76+
uint64_t outputRank = memRefShape.size();
7777

7878
Value input = operandAdaptor.input();
79-
8079
Value alloc = insertAllocAndDeallocSimple(
81-
rewriter, op, outputMemRefType, loc, shapeHelper.dimsForOutput(0));
82-
83-
// Define loops and iteration trip counts (equivalent to size of output)
84-
BuildKrnlLoop outputLoops(rewriter, loc, outputRank);
85-
outputLoops.createDefineOp();
86-
outputLoops.pushAllBounds(shapeHelper.dimsForOutput(0));
87-
outputLoops.createIterateOp();
88-
rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());
89-
90-
SmallVector<Value, 4> loadIndices;
91-
// This implementation is to iterate the output tensor.
92-
// The store has simple affine subscript expression.
93-
// Alternative implementation is to iterate the input tensor and repeats.
94-
// The load of elements in input tensor can be reused explicitly.
95-
// But the subscript of store is not contigious, or even not affine.
96-
// Alternative implementation can be found at the end of this file.
97-
98-
for (int64_t i = 0; i < outputRank; i++) {
99-
// Scope is created for each dimension because they are independent
100-
IndexExprScope IEScope(&rewriter, loc);
101-
DimIndexExpr index(outputLoops.getInductionVar(i));
102-
MemRefBoundsIndexCapture inputBounds(input);
103-
DimIndexExpr dimSize(inputBounds.getDim(i));
104-
IndexExpr exprVal = index % dimSize;
105-
loadIndices.emplace_back(exprVal.getValue());
106-
}
107-
108-
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
109-
rewriter, loc);
110-
Value loadVal = create.krnl.load(input, loadIndices);
111-
112-
SmallVector<Value, 4> storeIndices;
113-
for (int64_t i = 0; i < outputRank; ++i)
114-
storeIndices.emplace_back(outputLoops.getInductionVar(i));
115-
create.krnl.store(loadVal, alloc, storeIndices);
80+
rewriter, op, memRefType, loc, shapeHelper.dimsForOutput());
81+
82+
KrnlBuilder createKrnl(rewriter, loc);
83+
ValueRange loopDef = createKrnl.defineLoops(outputRank);
84+
SmallVector<IndexExpr, 4> lbs(outputRank, LiteralIndexExpr(0));
85+
86+
MemRefBoundsIndexCapture inputBounds(input);
87+
createKrnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.dimsForOutput(),
88+
[&](KrnlBuilder &createKrnl, ValueRange indices) {
89+
// Compute the indices used by the input tensor load operation.
90+
// Note: An alternative implementation can be found at the end of this
91+
// file.
92+
SmallVector<Value, 4> loadIndices;
93+
for (uint64_t i = 0; i < outputRank; ++i) {
94+
DimIndexExpr index(indices[i]);
95+
DimIndexExpr dimSize(inputBounds.getDim(i));
96+
IndexExpr exprVal = index % dimSize;
97+
loadIndices.emplace_back(exprVal.getValue());
98+
}
99+
100+
Value loadVal = createKrnl.load(input, loadIndices);
101+
createKrnl.store(loadVal, alloc, indices);
102+
});
116103

117104
rewriter.replaceOp(op, alloc);
118105

src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp

+23-27
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
3434

3535
// Basic information.
3636
auto memRefType = convertToMemRefType(*op->result_type_begin());
37-
int64_t rank = memRefType.getShape().size();
37+
uint64_t rank = memRefType.getShape().size();
3838

3939
// Get a shape helper.
4040
ONNXTransposeOpShapeHelper shapeHelper(&transposeOp, &rewriter,
@@ -46,32 +46,28 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
4646

4747
// Insert an allocation and deallocation for the result of this operation.
4848
Value alloc = insertAllocAndDeallocSimple(
49-
rewriter, op, memRefType, loc, shapeHelper.dimsForOutput(0));
50-
51-
// Create loop.
52-
BuildKrnlLoop inputLoops(rewriter, loc, rank);
53-
inputLoops.createDefineAndIterateOp(data);
54-
rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());
55-
{
56-
// Get a child IndexExpr context.
57-
IndexExprScope childScope(&rewriter, shapeHelper.scope);
58-
KrnlBuilder createKrnl(rewriter, loc);
59-
60-
// Get read/write indices.
61-
SmallVector<IndexExpr, 4> readIndices;
62-
SmallVector<IndexExpr, 4> writeIndices;
63-
for (decltype(rank) i = 0; i < rank; ++i) {
64-
Value readVal = inputLoops.getInductionVar(i);
65-
Value writeVal =
66-
inputLoops.getInductionVar(ArrayAttrIntVal(permAttr, i));
67-
readIndices.emplace_back(DimIndexExpr(readVal));
68-
writeIndices.emplace_back(DimIndexExpr(writeVal));
69-
}
70-
71-
// Copy data.
72-
Value loadData = createKrnl.loadIE(data, readIndices);
73-
createKrnl.storeIE(loadData, alloc, writeIndices);
74-
}
49+
rewriter, op, memRefType, loc, shapeHelper.dimsForOutput());
50+
51+
KrnlBuilder createKrnl(rewriter, loc);
52+
ValueRange loopDef = createKrnl.defineLoops(rank);
53+
SmallVector<IndexExpr, 4> lbs(rank, LiteralIndexExpr(0));
54+
55+
MemRefBoundsIndexCapture dataBounds(data);
56+
SmallVector<IndexExpr, 4> ubs;
57+
dataBounds.getDimList(ubs);
58+
59+
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
60+
[&](KrnlBuilder &createKrnl, ValueRange indices) {
61+
// Compute the indices used by the load operation.
62+
SmallVector<IndexExpr, 4> storeIndices;
63+
for (uint64_t i = 0; i < rank; ++i) {
64+
Value index = indices[ArrayAttrIntVal(permAttr, i)];
65+
storeIndices.emplace_back(DimIndexExpr(index));
66+
}
67+
68+
Value loadData = createKrnl.load(data, indices);
69+
createKrnl.storeIE(loadData, alloc, storeIndices);
70+
});
7571

7672
rewriter.replaceOp(op, alloc);
7773

src/Dialect/ONNX/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_onnx_mlir_library(OMONNXOps
2424
ShapeInference/ArgMax.cpp
2525
ShapeInference/AveragePool.cpp
2626
ShapeInference/CategoryMapper.cpp
27+
ShapeInference/Clip.cpp
2728
ShapeInference/Compress.cpp
2829
ShapeInference/Concat.cpp
2930
ShapeInference/Conv.cpp

0 commit comments

Comments
 (0)