Skip to content

Commit 70567a4

Browse files
authored
[js/web] use ApiTensor insteadof onnxjs Tensor in TensorResultValidator (#19358)
### Description use ApiTensor insteadof onnxjs Tensor in TensorResultValidator. Make test runner less depend on onnxjs classes.
1 parent 3fe2c13 commit 70567a4

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

js/web/test/test-runner.ts

+10-16
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001;
3939
*/
4040
const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now;
4141

42-
function toInternalTensor(tensor: ort.Tensor): Tensor {
43-
return new Tensor(
44-
tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType);
45-
}
4642
function fromInternalTensor(tensor: Tensor): ort.Tensor {
4743
return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims);
4844
}
@@ -330,6 +326,10 @@ export class TensorResultValidator {
330326
}
331327

332328
checkTensorResult(actual: Tensor[], expected: Tensor[]): void {
329+
this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor));
330+
}
331+
332+
checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
333333
// check output size
334334
expect(actual.length, 'size of output tensors').to.equal(expected.length);
335335

@@ -347,10 +347,6 @@ export class TensorResultValidator {
347347
}
348348
}
349349

350-
checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
351-
this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor));
352-
}
353-
354350
checkNamedTensorResult(actual: Record<string, ort.Tensor>, expected: Test.NamedTensor[]): void {
355351
// check output size
356352
expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length);
@@ -364,7 +360,7 @@ export class TensorResultValidator {
364360
}
365361

366362
// This function check whether 2 tensors should be considered as 'match' or not
367-
areEqual(actual: Tensor, expected: Tensor): boolean {
363+
areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean {
368364
if (!actual || !expected) {
369365
return false;
370366
}
@@ -392,13 +388,13 @@ export class TensorResultValidator {
392388

393389
switch (actualType) {
394390
case 'string':
395-
return this.strictEqual(actual.stringData, expected.stringData);
391+
return this.strictEqual(actual.data, expected.data);
396392

397393
case 'float32':
398394
case 'float64':
399395
return this.floatEqual(
400-
actual.numberData as number[] | Float32Array | Float64Array,
401-
expected.numberData as number[] | Float32Array | Float64Array);
396+
actual.data as number[] | Float32Array | Float64Array,
397+
expected.data as number[] | Float32Array | Float64Array);
402398

403399
case 'uint8':
404400
case 'int8':
@@ -409,10 +405,8 @@ export class TensorResultValidator {
409405
case 'int64':
410406
case 'bool':
411407
return TensorResultValidator.integerEqual(
412-
actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
413-
Int32Array,
414-
expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
415-
Int32Array);
408+
actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
409+
expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array);
416410

417411
default:
418412
throw new Error('type not implemented or not supported');

js/web/test/unittests/backends/webgl/test-conv-new.ts

+3-1
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,9 @@ describe('New Conv tests', () => {
893893
const expected = cpuConv(
894894
inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads,
895895
testData.strides);
896-
if (!validator.areEqual(actual, expected)) {
896+
try {
897+
validator.checkTensorResult([actual], [expected]);
898+
} catch {
897899
console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`);
898900
console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`);
899901
throw new Error('Expected and Actual did not match');

0 commit comments

Comments
 (0)