Skip to content

Commit 2dde4ba

Browse files
committed
[mlir][math] Added algebraic simplification for IPowI operation.
Differential Revision: https://reviews.llvm.org/D130390
1 parent 133624a commit 2dde4ba

File tree

2 files changed

+182
-1
lines changed

2 files changed

+182
-1
lines changed

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

+92-1
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,100 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
112112
return failure();
113113
}
114114

115+
//----------------------------------------------------------------------------//
116+
// IPowIOp strength reduction.
117+
//----------------------------------------------------------------------------//
118+
119+
namespace {
120+
struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
121+
unsigned exponentThreshold;
122+
123+
public:
124+
IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
125+
PatternBenefit benefit = 1,
126+
ArrayRef<StringRef> generatedNames = {})
127+
: OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
128+
exponentThreshold(exponentThreshold) {}
129+
LogicalResult matchAndRewrite(math::IPowIOp op,
130+
PatternRewriter &rewriter) const final;
131+
};
132+
} // namespace
133+
134+
/// Strength-reduces `math.ipowi` when the exponent is a compile-time constant
/// (a scalar constant or a splat dense constant) of small magnitude:
///
///   * `ipowi(x, 0)`  -> `1`
///   * `ipowi(x, n)`  -> `x * x * ...`             (n factors, 0 < n)
///   * `ipowi(x, -n)` -> `(1 / x) * (1 / x) * ...` (n factors, 0 < n)
///
/// Returns failure() without creating any operation when:
///   * the exponent is neither a scalar constant nor a splat constant;
///   * `abs(exponent)` is larger than `exponentThreshold`;
///   * a constant `1` of the result type is required (zero or negative
///     exponent) but the result is a container type other than a vector, for
///     which this pattern has no way of materializing that constant.
LogicalResult
IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
                                        PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value base = op.getLhs();

  IntegerAttr scalarExponent;
  DenseIntElementsAttr vectorExponent;

  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));

  // Only exponents with a single known value can be simplified.
  int64_t exponentValue = 0;
  if (isScalar)
    exponentValue = scalarExponent.getInt();
  else if (isVector && vectorExponent.isSplat())
    exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
  else
    return failure();

  // Compute `abs(exponent)` in unsigned arithmetic. Negating the signed value
  // directly overflows (undefined behaviour) for the minimum int64_t value;
  // the unsigned subtraction below is well defined and yields 2^63 there.
  const bool exponentIsNegative = exponentValue < 0;
  uint64_t exponentMagnitude = static_cast<uint64_t>(exponentValue);
  if (exponentIsNegative)
    exponentMagnitude = uint64_t{0} - exponentMagnitude;

  // Bail out if `abs(exponent)` exceeds the threshold.
  if (exponentMagnitude > exponentThreshold)
    return failure();

  Type resultType = op.getType();
  Type elementType = getElementTypeOrSelf(resultType);

  // A zero or negative exponent needs a constant `1` of the result type. It
  // can only be built for scalars (as is) and vectors (via broadcast); using
  // the scalar constant with any other container type would produce
  // mistyped IR, so give up before anything has been created.
  const bool needsOne = exponentMagnitude == 0 || exponentIsNegative;
  const bool isScalarResult = resultType == elementType;
  if (needsOne && !isScalarResult && !resultType.isa<VectorType>())
    return failure();

  // Materializes the constant `1`, broadcast to the vector result type when
  // `op` is a vector operation.
  auto createOne = [&]() -> Value {
    Value one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementType, 1));
    if (auto vec = resultType.dyn_cast<VectorType>())
      return rewriter.create<vector::BroadcastOp>(loc, vec, one);
    return one;
  };

  if (exponentMagnitude == 0) {
    // Replace `ipowi(x, 0)` with `1`.
    rewriter.replaceOp(op, createOne());
    return success();
  }

  // Inverse the base for negative exponent, i.e. for
  // `ipowi(x, negative_exponent)` set `x` to `1 / x`.
  if (exponentIsNegative)
    base = rewriter.create<arith::DivSIOp>(loc, createOne(), base);

  // Transform to naive sequence of multiplications:
  //   * For positive exponent case replace:
  //       `ipowi(x, positive_exponent)`
  //     with:
  //       x * x * x * ...
  //   * For negative exponent case replace:
  //       `ipowi(x, negative_exponent)`
  //     with:
  //       (1 / x) * (1 / x) * (1 / x) * ...
  Value result = base;
  for (uint64_t i = 1; i < exponentMagnitude; ++i)
    result = rewriter.create<arith::MulIOp>(loc, result, base);

  rewriter.replaceOp(op, result);
  return success();
}
204+
115205
//----------------------------------------------------------------------------//
116206

117207
// Adds the math algebraic-simplification patterns to `patterns`, constructing
// each one from `patterns.getContext()`. This hunk replaces the registration
// that added only PowFStrengthReduction (old line 119) with one that adds both
// PowFStrengthReduction and IPowIStrengthReduction (new lines 209-210).
void mlir::populateMathAlgebraicSimplificationPatterns(
118208
RewritePatternSet &patterns) {
119-
patterns.add<PowFStrengthReduction>(patterns.getContext());
209+
patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(
210+
patterns.getContext());
120211
}

mlir/test/Dialect/Math/algebraic-simplification.mlir

+90
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,93 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
7373
%1 = math.powf %arg1, %v : vector<4xf32>
7474
return %0, %1 : f32, vector<4xf32>
7575
}
76+
77+
// Zero exponent: both the scalar `math.ipowi` and the vector one (splat
// exponent) are expected to disappear, leaving the function returning the
// constants `1 : i32` and `dense<1> : vector<4xi32>`.
// CHECK-LABEL: @ipowi_zero_exp(
78+
// CHECK-SAME: %[[ARG0:.+]]: i32
79+
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
80+
// CHECK-SAME: -> (i32, vector<4xi32>) {
81+
func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>) {
82+
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
83+
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
84+
// CHECK: return %[[CST_S]], %[[CST_V]]
85+
%c = arith.constant 0 : i32
86+
%v = arith.constant dense <0> : vector<4xi32>
87+
%0 = math.ipowi %arg0, %c : i32
88+
%1 = math.ipowi %arg1, %v : vector<4xi32>
89+
return %0, %1 : i32, vector<4xi32>
90+
}
91+
92+
// Exponents 1 and -1, scalar and vector. With exponent 1 the base argument is
// expected to be returned unchanged; with exponent -1 the result is expected
// to be a single `arith.divsi` of the constant 1 by the base.
// CHECK-LABEL: @ipowi_exp_one(
93+
// CHECK-SAME: %[[ARG0:.+]]: i32
94+
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
95+
// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
96+
func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
97+
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
98+
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
99+
// CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
100+
// CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
101+
// CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
102+
%c1 = arith.constant 1 : i32
103+
%v1 = arith.constant dense <1> : vector<4xi32>
104+
%0 = math.ipowi %arg0, %c1 : i32
105+
%1 = math.ipowi %arg1, %v1 : vector<4xi32>
106+
%cm1 = arith.constant -1 : i32
107+
%vm1 = arith.constant dense <-1> : vector<4xi32>
108+
%2 = math.ipowi %arg0, %cm1 : i32
109+
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
110+
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
111+
}
112+
113+
// Exponents 2 and -2, scalar and vector. Exponent 2 is expected to become one
// `arith.muli` of the base with itself; exponent -2 is expected to become
// `arith.divsi` of the constant 1 by the base, followed by one `arith.muli`
// of that quotient with itself.
// CHECK-LABEL: @ipowi_exp_two(
114+
// CHECK-SAME: %[[ARG0:.+]]: i32
115+
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
116+
// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
117+
func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
118+
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
119+
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
120+
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
121+
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
122+
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
123+
// CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
124+
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
125+
// CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
126+
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
127+
%c1 = arith.constant 2 : i32
128+
%v1 = arith.constant dense <2> : vector<4xi32>
129+
%0 = math.ipowi %arg0, %c1 : i32
130+
%1 = math.ipowi %arg1, %v1 : vector<4xi32>
131+
%cm1 = arith.constant -2 : i32
132+
%vm1 = arith.constant dense <-2> : vector<4xi32>
133+
%2 = math.ipowi %arg0, %cm1 : i32
134+
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
135+
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
136+
}
137+
138+
// Exponents 3 and -3, scalar and vector. Exponent 3 is expected to become a
// chain of two `arith.muli` ops on the base; exponent -3 is expected to become
// `arith.divsi` of the constant 1 by the base, followed by a chain of two
// `arith.muli` ops on that quotient.
// CHECK-LABEL: @ipowi_exp_three(
139+
// CHECK-SAME: %[[ARG0:.+]]: i32
140+
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
141+
// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
142+
func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
143+
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
144+
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
145+
// CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
146+
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
147+
// CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
148+
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
149+
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
150+
// CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
151+
// CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
152+
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
153+
// CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
154+
// CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
155+
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
156+
%c1 = arith.constant 3 : i32
157+
%v1 = arith.constant dense <3> : vector<4xi32>
158+
%0 = math.ipowi %arg0, %c1 : i32
159+
%1 = math.ipowi %arg1, %v1 : vector<4xi32>
160+
%cm1 = arith.constant -3 : i32
161+
%vm1 = arith.constant dense <-3> : vector<4xi32>
162+
%2 = math.ipowi %arg0, %cm1 : i32
163+
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
164+
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
165+
}

0 commit comments

Comments
 (0)