@@ -28,73 +28,67 @@ struct ONNXClipOpLowering : public ConversionPattern {
28
28
LogicalResult matchAndRewrite (Operation *op, ArrayRef<Value> operands,
29
29
ConversionPatternRewriter &rewriter) const final {
30
30
Location loc = op->getLoc ();
31
- Value input = operands[0 ];
32
- Value min = operands[1 ];
33
- Value max = operands[2 ];
31
+ ONNXClipOp clipOp = cast<ONNXClipOp>(op);
32
+ MemRefType memRefType = convertToMemRefType (*op->result_type_begin ());
34
33
35
- // Insert an allocation and deallocation for the result of this operation.
36
- auto memRefType = convertToMemRefType (*op->result_type_begin ());
37
-
38
- Value alloc;
39
- bool insertDealloc = checkInsertDealloc (op);
34
+ ONNXClipOpAdaptor operandAdaptor (operands);
35
+ ONNXClipOpShapeHelper shapeHelper (&clipOp, &rewriter,
36
+ getDenseElementAttributeFromKrnlValue,
37
+ loadDenseElementArrayValueAtIndex);
38
+ auto shapeComputed = shapeHelper.computeShape (operandAdaptor);
39
+ assert (succeeded (shapeComputed));
40
40
41
- if (hasAllConstantDimensions (memRefType))
42
- alloc = insertAllocAndDealloc (memRefType, loc, rewriter, insertDealloc);
43
- else
44
- alloc = insertAllocAndDealloc (
45
- memRefType, loc, rewriter, insertDealloc, input);
41
+ Value input = operandAdaptor.input ();
42
+ Value min = operandAdaptor.min ();
43
+ Value max = operandAdaptor.max ();
46
44
47
- SmallVector<Value, 4 > loopIVs;
48
- // Only create krnl.iterate if one of the operands is not scalar tensor.
45
+ // Insert an allocation and deallocation for the result of this operation.
46
+ bool insertDealloc = checkInsertDealloc (op);
47
+ Value alloc =
48
+ (hasAllConstantDimensions (memRefType))
49
+ ? insertAllocAndDealloc (memRefType, loc, rewriter, insertDealloc)
50
+ : insertAllocAndDealloc (
51
+ memRefType, loc, rewriter, insertDealloc, input);
52
+
53
+ auto computeResult =
54
+ [&](MultiDialectBuilder<KrnlBuilder, MathBuilder> &create,
55
+ const ValueRange &indices) {
56
+ Value loadedVal = create.krnl .load (input, indices);
57
+ Value res = loadedVal;
58
+ if (!min.getType ().isa <NoneType>()) {
59
+ Value minVal = create.krnl .load (min);
60
+ Value lessThanMin = create.math .slt (res, minVal);
61
+ res = create.math .select (lessThanMin, minVal, res);
62
+ }
63
+ if (!max.getType ().isa <NoneType>()) {
64
+ Value maxVal = create.krnl .load (max);
65
+ Value lessThanMax = create.math .slt (res, maxVal);
66
+ res = create.math .select (lessThanMax, res, maxVal);
67
+ }
68
+ create.krnl .store (res, alloc, indices);
69
+ };
70
+
71
+ // Create a loop only is one of the operands is not a scalar tensor.
49
72
if (!hasAllScalarValues (operands)) {
50
- // Create iterateOp & get block within iterate op.
51
- BuildKrnlLoop loops (rewriter, loc, memRefType.getRank ());
52
- loops.createDefineAndIterateOp (input);
53
- Block *iterationBlock = loops.getIterateBlock ();
54
-
55
- // Insert instructions inside the KernelIterateOp body.
56
- rewriter.setInsertionPointToStart (iterationBlock);
57
-
58
- // Handle the operation:
59
- for (auto arg : iterationBlock->getArguments ())
60
- loopIVs.push_back (arg);
61
- }
62
-
63
- // Load unary first operand.
64
- MultiDialectBuilder<KrnlBuilder, MathBuilder> create (rewriter, loc);
65
- Value loadedVal = create.krnl .load (input, loopIVs);
66
- Type inputType = loadedVal.getType ();
67
- Value res = loadedVal;
68
-
69
- if (inputType.isa <FloatType>()) {
70
- if (!min.getType ().isa <NoneType>()) {
71
- Value minVal = create.krnl .load (min);
72
- Value lessThanMin = create.math .slt (res, minVal);
73
- res = create.math .select (lessThanMin, minVal, res);
74
- }
75
- if (!max.getType ().isa <NoneType>()) {
76
- Value maxVal = create.krnl .load (max);
77
- Value lessThanMax = create.math .slt (res, maxVal);
78
- res = create.math .select (lessThanMax, res, maxVal);
79
- }
80
- } else if (inputType.isa <IntegerType>()) {
81
- if (!min.getType ().isa <NoneType>()) {
82
- Value minVal = create.krnl .load (min);
83
- Value lessThanMin = create.math .slt (res, minVal);
84
- res = create.math .select (lessThanMin, minVal, res);
85
- }
86
- if (!max.getType ().isa <NoneType>()) {
87
- Value maxVal = create.krnl .load (max);
88
- Value lessThanMax = create.math .slt (res, maxVal);
89
- res = create.math .select (lessThanMax, res, maxVal);
90
- }
73
+ KrnlBuilder createKrnl (rewriter, loc);
74
+ uint64_t numLoops = memRefType.getRank ();
75
+ ValueRange loopDef = createKrnl.defineLoops (numLoops);
76
+
77
+ SmallVector<IndexExpr, 4 > lbs (numLoops, LiteralIndexExpr (0 ));
78
+ SmallVector<IndexExpr, 4 > ubs;
79
+ for (uint64_t i = 0 ; i < numLoops; ++i)
80
+ ubs.emplace_back (shapeHelper.dimsForOutput ()[i]);
81
+
82
+ createKrnl.iterateIE (loopDef, loopDef, lbs, ubs,
83
+ [&](KrnlBuilder &createKrnl, ValueRange indices) {
84
+ MultiDialectBuilder<KrnlBuilder, MathBuilder> create (createKrnl);
85
+ computeResult (create, indices);
86
+ });
91
87
} else {
92
- llvm_unreachable (" unsupported element type" );
88
+ MultiDialectBuilder<KrnlBuilder, MathBuilder> create (rewriter, loc);
89
+ computeResult (create, {});
93
90
}
94
91
95
- // Store result in the resulting array.
96
- create.krnl .store (res, alloc, loopIVs);
97
-
98
92
rewriter.replaceOp (op, alloc);
99
93
return success ();
100
94
}
0 commit comments