@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
259
259
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
260
260
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
261
261
262
- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
263
- if (llvm::failed (failureOrMaybeZps))
264
- return failure ();
262
+ // Get and verify zero points.
263
+ int64_t inputZpVal;
264
+ int64_t weightZpVal;
265
+
266
+ if (op.getInputZeroPoint (inputZpVal).failed () ||
267
+ op.getWeightZeroPoint (weightZpVal).failed ())
268
+ return rewriter.notifyMatchFailure (
269
+ op, " bail out if zero points cannot statically be determined" );
270
+
271
+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
272
+ op.verifyWeightZeroPoint (weightZpVal).failed ())
273
+ return rewriter.notifyMatchFailure (
274
+ op, " zero point must be zero for non-int8 integer types" );
265
275
266
- auto maybeZps = failureOrMaybeZps. value ( );
276
+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
267
277
268
278
if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
269
279
return rewriter.notifyMatchFailure (
@@ -289,19 +299,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
289
299
290
300
// Apply padding as necessary.
291
301
TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
292
- if (maybeZps ) {
302
+ if (hasZp ) {
293
303
int64_t intMin =
294
304
APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
295
305
.getSExtValue ();
296
306
int64_t intMax =
297
307
APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
298
308
.getSExtValue ();
299
309
300
- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
310
+ if (inputZpVal < intMin || inputZpVal > intMax)
301
311
return rewriter.notifyMatchFailure (
302
312
op, " tosa.conv op quantization has zp outside of input range" );
303
313
304
- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
314
+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
305
315
}
306
316
307
317
llvm::SmallVector<int64_t > pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
314
324
// For 2D convolutions, we need to check if the target convolution op
315
325
// wants a HWCF kernel layout.
316
326
bool wantHwcf =
317
- maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318
- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
327
+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
328
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
319
329
if (wantHwcf) {
320
330
// Transpose the kernel to match dimension ordering of the linalg
321
331
// convolution operation.
@@ -372,9 +382,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
372
382
Value broadcastBias =
373
383
linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
374
384
375
- if (maybeZps ) {
376
- auto iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
377
- auto kZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
385
+ if (hasZp ) {
386
+ auto iZp = rewriter.getI32IntegerAttr (inputZpVal );
387
+ auto kZp = rewriter.getI32IntegerAttr (weightZpVal );
378
388
379
389
auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
380
390
auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -437,31 +447,40 @@ class DepthwiseConvConverter
437
447
/* inputSizeDims=*/ {1 , 2 },
438
448
/* kernelSizeDims=*/ {0 , 1 }, rewriter);
439
449
440
- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
441
- if (llvm::failed (failureOrMaybeZps))
442
- return failure ();
450
+ // Get and verify zero points.
451
+ int64_t inputZpVal;
452
+ int64_t weightZpVal;
453
+
454
+ if (op.getInputZeroPoint (inputZpVal).failed () ||
455
+ op.getWeightZeroPoint (weightZpVal).failed ())
456
+ return rewriter.notifyMatchFailure (
457
+ op, " bail out if zero points cannot statically be determined" );
443
458
444
- auto maybeZps = failureOrMaybeZps.value ();
459
+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
460
+ op.verifyWeightZeroPoint (weightZpVal).failed ())
461
+ return rewriter.notifyMatchFailure (
462
+ op, " zero point must be zero for non-int8 integer types" );
445
463
464
+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
446
465
auto weightShape = weightTy.getShape ();
447
466
auto resultShape = resultTy.getShape ();
448
467
449
468
// Apply padding as necessary.
450
469
TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
451
- if (maybeZps ) {
470
+ if (hasZp ) {
452
471
int64_t intMin =
453
472
APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
454
473
.getSExtValue ();
455
474
int64_t intMax =
456
475
APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
457
476
.getSExtValue ();
458
477
459
- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
478
+ if (inputZpVal < intMin || inputZpVal > intMax)
460
479
return rewriter.notifyMatchFailure (
461
480
op, " tosa.depthwise_conv op quantization has zp outside of input "
462
481
" range" );
463
482
464
- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
483
+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
465
484
}
466
485
467
486
llvm::SmallVector<int64_t > pad;
@@ -501,7 +520,7 @@ class DepthwiseConvConverter
501
520
indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
502
521
indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
503
522
504
- if (!maybeZps ) {
523
+ if (!hasZp ) {
505
524
Value conv = rewriter
506
525
.create <linalg::DepthwiseConv2DNhwcHwcmOp>(
507
526
loc, linalgConvTy, ValueRange{input, weight},
@@ -528,8 +547,8 @@ class DepthwiseConvConverter
528
547
.getResult (0 );
529
548
rewriter.replaceOp (op, result);
530
549
} else {
531
- IntegerAttr iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
532
- IntegerAttr wZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
550
+ IntegerAttr iZp = rewriter.getI32IntegerAttr (inputZpVal );
551
+ IntegerAttr wZp = rewriter.getI32IntegerAttr (weightZpVal );
533
552
auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
534
553
auto kZpVal = rewriter.create <arith::ConstantOp>(loc, wZp);
535
554
Value conv =
0 commit comments