@@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001;
39
39
*/
40
40
const now = ( typeof performance !== 'undefined' && performance . now ) ? ( ) => performance . now ( ) : Date . now ;
41
41
42
- function toInternalTensor ( tensor : ort . Tensor ) : Tensor {
43
- return new Tensor (
44
- tensor . dims , tensor . type as Tensor . DataType , undefined , undefined , tensor . data as Tensor . NumberType ) ;
45
- }
46
42
function fromInternalTensor ( tensor : Tensor ) : ort . Tensor {
47
43
return new ort . Tensor ( tensor . type , tensor . data as ort . Tensor . DataType , tensor . dims ) ;
48
44
}
@@ -330,6 +326,10 @@ export class TensorResultValidator {
330
326
}
331
327
332
328
checkTensorResult ( actual : Tensor [ ] , expected : Tensor [ ] ) : void {
329
+ this . checkApiTensorResult ( actual . map ( fromInternalTensor ) , expected . map ( fromInternalTensor ) ) ;
330
+ }
331
+
332
+ checkApiTensorResult ( actual : ort . Tensor [ ] , expected : ort . Tensor [ ] ) : void {
333
333
// check output size
334
334
expect ( actual . length , 'size of output tensors' ) . to . equal ( expected . length ) ;
335
335
@@ -347,10 +347,6 @@ export class TensorResultValidator {
347
347
}
348
348
}
349
349
350
- checkApiTensorResult ( actual : ort . Tensor [ ] , expected : ort . Tensor [ ] ) : void {
351
- this . checkTensorResult ( actual . map ( toInternalTensor ) , expected . map ( toInternalTensor ) ) ;
352
- }
353
-
354
350
checkNamedTensorResult ( actual : Record < string , ort . Tensor > , expected : Test . NamedTensor [ ] ) : void {
355
351
// check output size
356
352
expect ( Object . getOwnPropertyNames ( actual ) . length , 'size of output tensors' ) . to . equal ( expected . length ) ;
@@ -364,7 +360,7 @@ export class TensorResultValidator {
364
360
}
365
361
366
362
// This function check whether 2 tensors should be considered as 'match' or not
367
- areEqual ( actual : Tensor , expected : Tensor ) : boolean {
363
+ areEqual ( actual : ort . Tensor , expected : ort . Tensor ) : boolean {
368
364
if ( ! actual || ! expected ) {
369
365
return false ;
370
366
}
@@ -392,13 +388,13 @@ export class TensorResultValidator {
392
388
393
389
switch ( actualType ) {
394
390
case 'string' :
395
- return this . strictEqual ( actual . stringData , expected . stringData ) ;
391
+ return this . strictEqual ( actual . data , expected . data ) ;
396
392
397
393
case 'float32' :
398
394
case 'float64' :
399
395
return this . floatEqual (
400
- actual . numberData as number [ ] | Float32Array | Float64Array ,
401
- expected . numberData as number [ ] | Float32Array | Float64Array ) ;
396
+ actual . data as number [ ] | Float32Array | Float64Array ,
397
+ expected . data as number [ ] | Float32Array | Float64Array ) ;
402
398
403
399
case 'uint8' :
404
400
case 'int8' :
@@ -409,10 +405,8 @@ export class TensorResultValidator {
409
405
case 'int64' :
410
406
case 'bool' :
411
407
return TensorResultValidator . integerEqual (
412
- actual . numberData as number [ ] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
413
- Int32Array ,
414
- expected . numberData as number [ ] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
415
- Int32Array ) ;
408
+ actual . data as number [ ] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array ,
409
+ expected . data as number [ ] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array ) ;
416
410
417
411
default :
418
412
throw new Error ( 'type not implemented or not supported' ) ;
0 commit comments