diff --git a/tests/utils/engine_mocking.py b/tests/utils/engine_mocking.py index a56783a287..cef0b60164 100644 --- a/tests/utils/engine_mocking.py +++ b/tests/utils/engine_mocking.py @@ -103,7 +103,9 @@ def __init__( with override_onnx_batch_size( model_path, batch_size, inplace=True ) as batched_model_path: - session = ort.InferenceSession(batched_model_path) + session = ort.InferenceSession( + batched_model_path, providers=["CPUExecutionProvider"] + ) self.input_descriptors = list(map(_to_descriptor, session.get_inputs())) self.output_descriptors = list(map(_to_descriptor, session.get_outputs()))