Skip to content

Commit dfbadfc

Browse files
[mlir][tosa] Change MatMul zero-point to inputs (#130332)
* Change zero-point attributes to inputs * Fix relevant mlir tests * Enhance ShardingInterface in MatMul Signed-off-by: Udaya Ranga <[email protected]> Co-authored-by: Udaya Ranga <[email protected]>
1 parent cb3ce30 commit dfbadfc

15 files changed

+238
-100
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

+8-6
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ profileComplianceMap = {
3535
{fp16T, fp16T, fp32T, fp32T},
3636
{fp32T, fp32T, fp32T, fp32T}}}}},
3737
{"tosa.matmul",
38-
{{{Profile::pro_int}, {{i8T, i8T, i32T}}},
38+
{{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
3939
{{Profile::pro_fp},
40-
{{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
40+
{{fp16T, fp16T, fp16T, fp16T, fp16T},
41+
{fp16T, fp16T, fp16T, fp16T, fp32T},
42+
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
4143
{"tosa.max_pool2d",
4244
{{{Profile::pro_int}, {{i8T, i8T}}},
4345
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -273,10 +275,10 @@ extensionComplianceMap = {
273275
{{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
274276
{{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
275277
{"tosa.matmul",
276-
{{{Extension::int16}, {{i16T, i16T, i48T}}},
277-
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
278-
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
279-
{{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
278+
{{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
279+
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
280+
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
281+
{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
280282
{"tosa.max_pool2d",
281283
{{{Extension::int16}, {{i16T, i16T}}},
282284
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+9-2
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
311311
let arguments = (ins
312312
Tosa_Tensor3D:$a,
313313
Tosa_Tensor3D:$b,
314-
OptionalAttr<I32Attr>:$a_zp,
315-
OptionalAttr<I32Attr>:$b_zp
314+
Tosa_ScalarIntOrFloatTensor:$a_zp,
315+
Tosa_ScalarIntOrFloatTensor:$b_zp
316316
);
317317

318318
let results = (outs
@@ -324,6 +324,13 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
324324
Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
325325
];
326326

327+
let extraClassDeclaration = [{
328+
FailureOr<int64_t> getAZeroPoint();
329+
FailureOr<int64_t> getBZeroPoint();
330+
LogicalResult verifyAZeroPoint(int64_t zp);
331+
LogicalResult verifyBZeroPoint(int64_t zp);
332+
}];
333+
327334
let builders = [Tosa_MatMulOpQuantInfoBuilder];
328335
let hasVerifier = 1;
329336
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

+32-9
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
270270
return rewriter.notifyMatchFailure(
271271
op, "weight zero point cannot be statically determined");
272272

273-
int64_t inputZpVal = *maybeIZp;
274-
int64_t weightZpVal = *maybeWZp;
273+
const int64_t inputZpVal = *maybeIZp;
274+
const int64_t weightZpVal = *maybeWZp;
275275

276276
if (op.verifyInputZeroPoint(inputZpVal).failed())
277277
return rewriter.notifyMatchFailure(
@@ -466,8 +466,8 @@ class DepthwiseConvConverter
466466
return rewriter.notifyMatchFailure(
467467
op, "weight zero point cannot be statically determined");
468468

469-
int64_t inputZpVal = *maybeIZp;
470-
int64_t weightZpVal = *maybeWZp;
469+
const int64_t inputZpVal = *maybeIZp;
470+
const int64_t weightZpVal = *maybeWZp;
471471

472472
if (op.verifyInputZeroPoint(inputZpVal).failed())
473473
return rewriter.notifyMatchFailure(
@@ -621,15 +621,38 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
621621
.create<linalg::FillOp>(loc, ValueRange{zero},
622622
ValueRange{emptyTensor})
623623
.result();
624-
if (!op.getAZp() && !op.getBZp()) {
624+
625+
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
626+
FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
627+
if (failed(maybeAZp))
628+
return rewriter.notifyMatchFailure(
629+
op, "input a zero point cannot be statically determined");
630+
if (failed(maybeBZp))
631+
return rewriter.notifyMatchFailure(
632+
op, "input b zero point cannot be statically determined");
633+
634+
const int64_t aZpVal = *maybeAZp;
635+
const int64_t bZpVal = *maybeBZp;
636+
637+
if (op.verifyAZeroPoint(aZpVal).failed())
638+
return rewriter.notifyMatchFailure(
639+
op, "input a zero point must be zero for non-int8 integer types");
640+
641+
if (op.verifyBZeroPoint(bZpVal).failed())
642+
return rewriter.notifyMatchFailure(
643+
op, "input b zero point must be zero for non-int8 integer types");
644+
645+
if (aZpVal == 0 && bZpVal == 0) {
625646
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
626647
op, TypeRange{op.getType()},
627648
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
628649
return success();
629650
}
630651

631-
auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
632-
auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
652+
auto aZp = rewriter.create<arith::ConstantOp>(
653+
loc, rewriter.getI32IntegerAttr(aZpVal));
654+
auto bZp = rewriter.create<arith::ConstantOp>(
655+
loc, rewriter.getI32IntegerAttr(bZpVal));
633656
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
634657
op, TypeRange{op.getType()},
635658
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -834,8 +857,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
834857
return rewriter.notifyMatchFailure(
835858
op, "output zero point could not be statically determined");
836859

837-
int64_t inputZpVal = *maybeIZp;
838-
int64_t outputZpVal = *maybeOZp;
860+
const int64_t inputZpVal = *maybeIZp;
861+
const int64_t outputZpVal = *maybeOZp;
839862

840863
// Apply padding as necessary.
841864
llvm::SmallVector<int64_t> pad;

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ struct MatMulOpSharding
5555
SmallVector<AffineMap> maps;
5656
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
5757
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
58+
maps.push_back(AffineMap::get(0, 0, {}, ctx));
59+
maps.push_back(AffineMap::get(0, 0, {}, ctx));
5860
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
5961
return maps;
6062
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+39-25
Original file line numberDiff line numberDiff line change
@@ -629,23 +629,13 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
629629
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
630630
OperationState &result, Type outputType,
631631
Value a, Value b) {
632-
result.addOperands({a, b});
633-
auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
632+
auto zps = createZPsAsConst(builder, a, b);
633+
result.addOperands({a, b, zps.first, zps.second});
634634

635-
if (quantAttr) {
636-
result.addAttribute("a_zp", builder.getI32IntegerAttr(
637-
static_cast<int32_t>(quantAttr.getAZp())));
638-
result.addAttribute("b_zp", builder.getI32IntegerAttr(
639-
static_cast<int32_t>(quantAttr.getBZp())));
640-
641-
auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
642-
assert(inputType && "Input must be a shaped tensor type!");
643-
644-
auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
645-
inputType.getElementType());
646-
assert(inputQType && "Tensor must have quantized datatype!");
647-
648-
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
635+
Type finalOutputType{outputType};
636+
if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
637+
auto eType = getStorageElementTypeOrSelf(a.getType());
638+
auto inputBits = eType.getIntOrFloatBitWidth();
649639

650640
auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
651641
assert(outputShapedType && "Output must be a shaped type");
@@ -655,11 +645,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
655645
accElementType = builder.getIntegerType(48);
656646
else
657647
accElementType = builder.getI32Type();
658-
auto accType = outputShapedType.clone(accElementType);
659-
result.addTypes(accType);
660-
} else {
661-
result.addTypes(outputType);
648+
649+
finalOutputType = outputShapedType.clone(accElementType);
662650
}
651+
result.addTypes(finalOutputType);
663652
}
664653

665654
/// Both the tosa.avg_pool2d and unary ops use the same
@@ -1140,16 +1129,39 @@ LogicalResult MatMulOp::verify() {
11401129
return emitOpError("expect quantized operands to have same widths, got ")
11411130
<< aQuantWidth << " and " << bQuantWidth;
11421131
}
1132+
} else {
1133+
// non-quantized element types
1134+
if (aElementType != bElementType) {
1135+
return emitOpError("expect same element type for inputs a and b, got ")
1136+
<< aElementType << " and " << bElementType;
1137+
}
1138+
}
11431139

1144-
return success();
1140+
// check a_zp and b_zp
1141+
auto aEType = getStorageElementTypeOrSelf(aType);
1142+
auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1143+
if (aEType != aZpEType) {
1144+
return emitOpError("expect input a and a_zp have the same "
1145+
"element type, got ")
1146+
<< aEType << " and " << aZpEType;
11451147
}
11461148

1147-
// non-quantized element types
1148-
if (aElementType != bElementType) {
1149-
return emitOpError("expect same element type for inputs a and b, got ")
1150-
<< aElementType << " and " << bElementType;
1149+
auto bEType = getStorageElementTypeOrSelf(bType);
1150+
auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1151+
if (bEType != bZpEType) {
1152+
return emitOpError("expect input b and b_zp have the same "
1153+
"element type, got ")
1154+
<< bEType << " and " << bZpEType;
11511155
}
11521156

1157+
FailureOr<int64_t> maybeAZp = getAZeroPoint();
1158+
if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1159+
return failure();
1160+
1161+
FailureOr<int64_t> maybeBZp = getBZeroPoint();
1162+
if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1163+
return failure();
1164+
11531165
return success();
11541166
}
11551167

@@ -1714,6 +1726,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
17141726
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
17151727
ZERO_POINT_HELPER(AvgPool2dOp, Input)
17161728
ZERO_POINT_HELPER(AvgPool2dOp, Output)
1729+
ZERO_POINT_HELPER(MatMulOp, A)
1730+
ZERO_POINT_HELPER(MatMulOp, B)
17171731
#undef ZERO_POINT_HELPER
17181732

17191733
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,15 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
178178
addValue(op.getOutput());
179179
}
180180

181+
template <>
182+
void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
183+
addValue(op.getA());
184+
addValue(op.getB());
185+
addValue(op.getAZp());
186+
addValue(op.getBZp());
187+
addValue(op.getOutput());
188+
}
189+
181190
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
182191
// This helper function only populates the info for the customised operands.
183192
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
@@ -218,6 +227,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
218227
POPULATE_PROFILE_INFO_CUSTOM(Resize)
219228
POPULATE_PROFILE_INFO_CUSTOM(Select)
220229
POPULATE_PROFILE_INFO_CUSTOM(Rescale)
230+
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
221231

222232
// Type Invariant Extension, a capability extension that is independent
223233
// of the data type, meaning any compatible type can be used. No type
@@ -235,7 +245,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
235245
POPULATE_PROFILE_INFO_COMMON(Cast)
236246
POPULATE_PROFILE_INFO_COMMON(Const)
237247
POPULATE_PROFILE_INFO_COMMON(ArgMax)
238-
POPULATE_PROFILE_INFO_COMMON(MatMul)
239248
POPULATE_PROFILE_INFO_COMMON(Sub)
240249
POPULATE_PROFILE_INFO_COMMON(Maximum)
241250
POPULATE_PROFILE_INFO_COMMON(Minimum)

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

+18-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor
88
// CHECK: [[INIT:%.+]] = tensor.empty()
99
// CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
1010
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
11-
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> tensor<1x5x6xf32>
11+
%a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
12+
%b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
13+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
1214
return %0 : tensor<1x5x6xf32>
1315
}
1416

@@ -23,7 +25,9 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
2325
// CHECK: [[ONE:%.+]] = arith.constant 1
2426
// CHECK: [[TWO:%.+]] = arith.constant 2
2527
// CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
26-
%0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
28+
%a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
29+
%b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
30+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
2731
return %0 : tensor<1x5x6xi32>
2832
}
2933

@@ -37,7 +41,9 @@ func.func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>)
3741
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
3842
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
3943
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
40-
%0 = tosa.matmul %arg0, %arg1 : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>
44+
%a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
45+
%b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
46+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<?x5x3xf32>, tensor<?x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x6xf32>
4147
return %0 : tensor<?x5x6xf32>
4248
}
4349

@@ -51,7 +57,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x
5157
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
5258
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
5359
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
54-
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32>
60+
%a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
61+
%b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
62+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x?xf32>
5563
return %0 : tensor<1x5x?xf32>
5664
}
5765

@@ -63,7 +71,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
6371
// CHECK: %[[INIT:.+]] = tensor.empty()
6472
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
6573
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
66-
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32>
74+
%a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
75+
%b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
76+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
6777
return %0 : tensor<1x5x6xf32>
6878
}
6979

@@ -77,7 +87,9 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
7787
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
7888
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
7989
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
80-
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
90+
%a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
91+
%b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
92+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1x1xf32>
8193
return %0 : tensor<?x1x1xf32>
8294
}
8395

0 commit comments

Comments
 (0)