Skip to content

Commit ca5bb23

Browse files
authored
[mlir][tosa] Change zero points of convolution ops to required inputs (#127679)
This patch changes the input_zp and weight_zp for convolution operators to be required inputs in order to align with the TOSA Spec 1.0. Convolution operators affected are: CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D. Signed-off-by: Tai Ly <[email protected]>
1 parent be28365 commit ca5bb23

18 files changed

+580
-491
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -168,112 +168,6 @@ namespace tosa {
168168
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
169169
Type srcElemType, int64_t zp = 0);
170170

171-
// Get zero point value from the attribute argument.
172-
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
173-
174-
// Verify that the zero point falls into the valid range for its element type.
175-
template <typename T>
176-
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
177-
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
178-
!std::is_same_v<T, DepthwiseConv2DOp> &&
179-
!std::is_same_v<T, TransposeConv2DOp>) {
180-
return failure();
181-
}
182-
183-
if (!zpElemType.isIntOrFloat())
184-
return failure();
185-
186-
if (!zpElemType.isInteger(8) && zp != 0)
187-
return failure();
188-
189-
if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
190-
return failure();
191-
192-
if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
193-
return failure();
194-
195-
return success();
196-
}
197-
198-
// Helper type trait to determine if an operation is a tosa convolution.
199-
template <typename Op>
200-
struct IsTosaConv : std::false_type {};
201-
202-
template <>
203-
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
204-
template <>
205-
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
206-
template <>
207-
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
208-
template <>
209-
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
210-
211-
template <typename Op>
212-
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
213-
214-
// Helper struct to hold the zero points of a TOSA convolution operation as
215-
// named 64-bit integer fields.
216-
struct ConvZpPair {
217-
ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
218-
: inputZp(inputZp), weightZp(weightZp) {}
219-
std::int64_t inputZp;
220-
std::int64_t weightZp;
221-
};
222-
223-
// Helper function which attempts to extract the zero points from a TOSA
224-
// convolution by matching them against defining ops which should be tosa.const
225-
// operations.
226-
//
227-
// There are three possible results:
228-
// 1. Failed to extract the zero-points, i.e. they should exist and don't, or they
229-
// do exist but are invalid.
230-
// 2. Succeeded in extracting zero-points.
231-
// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
232-
// convolution.
233-
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
234-
template <typename TosaConvOp>
235-
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
236-
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
237-
// Strictly speaking the base TOSA spec requires that for non int8 types
238-
// zero points must be zero. However, in the dialect these operands are
239-
// optional and only required for int8. They have no semantic meaning for
240-
// non-quantized types and can therefore be safely ignored. This is case 3.
241-
if (auto opElementTY =
242-
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
243-
!opElementTY.isInteger(8))
244-
return FailOrMaybeZP(std::nullopt);
245-
246-
// Now we know we should have a zero point; check that it is valid.
247-
if (!op.getInputZp())
248-
return rewriter.notifyMatchFailure(op, "missing input zero point");
249-
250-
// Helper to extract the zero point by matching its definition against a
251-
// constant.
252-
auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
253-
ElementsAttr zpAttr;
254-
if (!matchPattern(zpValue, m_Constant(&zpAttr)))
255-
return std::nullopt;
256-
257-
int64_t zp;
258-
if (tosa::getZeroPoint(zpAttr, zp).failed())
259-
return std::nullopt;
260-
261-
return std::make_optional(zp);
262-
};
263-
264-
auto maybeInputZp = extractZeroPoint(op.getInputZp());
265-
if (!maybeInputZp)
266-
return rewriter.notifyMatchFailure(op, "unable to extract input zp");
267-
268-
if (!op.getWeightZp())
269-
return rewriter.notifyMatchFailure(op, "missing weight zero point");
270-
271-
auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
272-
if (!maybeWeightZp)
273-
return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
274-
275-
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
276-
}
277171
} // namespace tosa
278172
} // namespace mlir
279173

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
116116
Tosa_Tensor4D:$input,
117117
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
118118
Tosa_Tensor1D:$bias,
119-
Optional<Tosa_ScalarTensor>:$input_zp,
120-
Optional<Tosa_ScalarTensor>:$weight_zp,
119+
Tosa_ScalarTensor:$input_zp,
120+
Tosa_ScalarTensor:$weight_zp,
121+
121122
Tosa_IntArrayAttr4:$pad,
122123
Tosa_IntArrayAttr2:$stride,
123124
Tosa_IntArrayAttr2:$dilation,
@@ -134,6 +135,13 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
134135
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
135136
];
136137

138+
let extraClassDeclaration = [{
139+
LogicalResult getInputZeroPoint(int64_t &zp);
140+
LogicalResult getWeightZeroPoint(int64_t &zp);
141+
LogicalResult verifyInputZeroPoint(int64_t zp);
142+
LogicalResult verifyWeightZeroPoint(int64_t zp);
143+
}];
144+
137145
let builders = [Tosa_ConvOpQuantInfoBuilder];
138146
let hasVerifier = 1;
139147
}
@@ -153,8 +161,9 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
153161
Tosa_Tensor5D:$input,
154162
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
155163
Tosa_Tensor1D:$bias,
156-
Optional<Tosa_ScalarTensor>:$input_zp,
157-
Optional<Tosa_ScalarTensor>:$weight_zp,
164+
Tosa_ScalarTensor:$input_zp,
165+
Tosa_ScalarTensor:$weight_zp,
166+
158167
Tosa_IntArrayAttr6:$pad,
159168
Tosa_IntArrayAttr3:$stride,
160169
Tosa_IntArrayAttr3:$dilation,
@@ -171,6 +180,13 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
171180
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
172181
];
173182

183+
let extraClassDeclaration = [{
184+
LogicalResult getInputZeroPoint(int64_t &zp);
185+
LogicalResult getWeightZeroPoint(int64_t &zp);
186+
LogicalResult verifyInputZeroPoint(int64_t zp);
187+
LogicalResult verifyWeightZeroPoint(int64_t zp);
188+
}];
189+
174190
let builders = [Tosa_ConvOpQuantInfoBuilder];
175191
let hasVerifier = 1;
176192
}
@@ -191,8 +207,9 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
191207
Tosa_Tensor4D:$input,
192208
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
193209
Tosa_Tensor1D:$bias,
194-
Optional<Tosa_ScalarTensor>:$input_zp,
195-
Optional<Tosa_ScalarTensor>:$weight_zp,
210+
Tosa_ScalarTensor:$input_zp,
211+
Tosa_ScalarTensor:$weight_zp,
212+
196213
Tosa_IntArrayAttr4:$pad,
197214
Tosa_IntArrayAttr2:$stride,
198215
Tosa_IntArrayAttr2:$dilation,
@@ -209,6 +226,13 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
209226
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
210227
];
211228

229+
let extraClassDeclaration = [{
230+
LogicalResult getInputZeroPoint(int64_t &zp);
231+
LogicalResult getWeightZeroPoint(int64_t &zp);
232+
LogicalResult verifyInputZeroPoint(int64_t zp);
233+
LogicalResult verifyWeightZeroPoint(int64_t zp);
234+
}];
235+
212236
let builders = [Tosa_ConvOpQuantInfoBuilder];
213237
let hasVerifier = 1;
214238
}
@@ -379,8 +403,9 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
379403
Tosa_Tensor4D:$input,
380404
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
381405
Tosa_Tensor1D:$bias,
382-
Optional<Tosa_ScalarTensor>:$input_zp,
383-
Optional<Tosa_ScalarTensor>:$weight_zp,
406+
Tosa_ScalarTensor:$input_zp,
407+
Tosa_ScalarTensor:$weight_zp,
408+
384409
Tosa_IntArrayAttr4:$out_pad,
385410
Tosa_IntArrayAttr2:$stride,
386411
Tosa_IntArrayAttr4:$out_shape,
@@ -397,6 +422,13 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
397422
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
398423
];
399424

425+
let extraClassDeclaration = [{
426+
LogicalResult getInputZeroPoint(int64_t &zp);
427+
LogicalResult getWeightZeroPoint(int64_t &zp);
428+
LogicalResult verifyInputZeroPoint(int64_t zp);
429+
LogicalResult verifyWeightZeroPoint(int64_t zp);
430+
}];
431+
400432
let builders = [Tosa_TransConvOpQuantInfoBuilder];
401433
let hasVerifier = 1;
402434
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
259259
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
260260
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261261

262-
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
263-
if (llvm::failed(failureOrMaybeZps))
264-
return failure();
262+
// Get and verify zero points.
263+
int64_t inputZpVal;
264+
int64_t weightZpVal;
265+
266+
if (op.getInputZeroPoint(inputZpVal).failed() ||
267+
op.getWeightZeroPoint(weightZpVal).failed())
268+
return rewriter.notifyMatchFailure(
269+
op, "bail out if zero points cannot statically be determined");
270+
271+
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
272+
op.verifyWeightZeroPoint(weightZpVal).failed())
273+
return rewriter.notifyMatchFailure(
274+
op, "zero point must be zero for non-int8 integer types");
265275

266-
auto maybeZps = failureOrMaybeZps.value();
276+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
267277

268278
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
269279
return rewriter.notifyMatchFailure(
@@ -289,19 +299,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
289299

290300
// Apply padding as necessary.
291301
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
292-
if (maybeZps) {
302+
if (hasZp) {
293303
int64_t intMin =
294304
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
295305
.getSExtValue();
296306
int64_t intMax =
297307
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
298308
.getSExtValue();
299309

300-
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
310+
if (inputZpVal < intMin || inputZpVal > intMax)
301311
return rewriter.notifyMatchFailure(
302312
op, "tosa.conv op quantization has zp outside of input range");
303313

304-
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
314+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
305315
}
306316

307317
llvm::SmallVector<int64_t> pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
314324
// For 2D convolutions, we need to check if the target convolution op
315325
// wants a HWCF kernel layout.
316326
bool wantHwcf =
317-
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318-
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
327+
hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
328+
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
319329
if (wantHwcf) {
320330
// Transpose the kernel to match dimension ordering of the linalg
321331
// convolution operation.
@@ -372,9 +382,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
372382
Value broadcastBias =
373383
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
374384

375-
if (maybeZps) {
376-
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
377-
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
385+
if (hasZp) {
386+
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
387+
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
378388

379389
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
380390
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -437,31 +447,40 @@ class DepthwiseConvConverter
437447
/*inputSizeDims=*/{1, 2},
438448
/*kernelSizeDims=*/{0, 1}, rewriter);
439449

440-
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
441-
if (llvm::failed(failureOrMaybeZps))
442-
return failure();
450+
// Get and verify zero points.
451+
int64_t inputZpVal;
452+
int64_t weightZpVal;
453+
454+
if (op.getInputZeroPoint(inputZpVal).failed() ||
455+
op.getWeightZeroPoint(weightZpVal).failed())
456+
return rewriter.notifyMatchFailure(
457+
op, "bail out if zero points cannot statically be determined");
443458

444-
auto maybeZps = failureOrMaybeZps.value();
459+
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
460+
op.verifyWeightZeroPoint(weightZpVal).failed())
461+
return rewriter.notifyMatchFailure(
462+
op, "zero point must be zero for non-int8 integer types");
445463

464+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
446465
auto weightShape = weightTy.getShape();
447466
auto resultShape = resultTy.getShape();
448467

449468
// Apply padding as necessary.
450469
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
451-
if (maybeZps) {
470+
if (hasZp) {
452471
int64_t intMin =
453472
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
454473
.getSExtValue();
455474
int64_t intMax =
456475
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
457476
.getSExtValue();
458477

459-
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
478+
if (inputZpVal < intMin || inputZpVal > intMax)
460479
return rewriter.notifyMatchFailure(
461480
op, "tosa.depthwise_conv op quantization has zp outside of input "
462481
"range");
463482

464-
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
483+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
465484
}
466485

467486
llvm::SmallVector<int64_t> pad;
@@ -501,7 +520,7 @@ class DepthwiseConvConverter
501520
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
502521
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
503522

504-
if (!maybeZps) {
523+
if (!hasZp) {
505524
Value conv = rewriter
506525
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
507526
loc, linalgConvTy, ValueRange{input, weight},
@@ -528,8 +547,8 @@ class DepthwiseConvConverter
528547
.getResult(0);
529548
rewriter.replaceOp(op, result);
530549
} else {
531-
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
532-
IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
550+
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
551+
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
533552
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
534553
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
535554
Value conv =

0 commit comments

Comments
 (0)