diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index 5ce98ce6a0..a73cddc24b 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -60,15 +60,16 @@ def save_onnx_to_temp_files(model: onnx.ModelProto, with_external_data=False) -> :param model: The onnx model to save to temporary directory :param with_external_data: Whether to save external data to a separate file """ - shaped_model = tempfile.NamedTemporaryFile(mode="w", delete=False) - _LOGGER.info(f"Saving model to temporary directory: {tempfile.tempdir}") + + shaped_model = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False, mode="w") + _LOGGER.info(f"Saving model to temporary file: {shaped_model.name}") if with_external_data: - external_data = os.path.join( - tempfile.tempdir, next(tempfile._get_candidate_names()) + external_data = tempfile.NamedTemporaryFile( + suffix=".data", delete=False, mode="w" ) - has_external_data = save_onnx(model, shaped_model.name, external_data) - _LOGGER.info(f"Saving external data to temporary directory: {external_data}") + _LOGGER.info(f"Saving external data to temporary file: {external_data.name}") + has_external_data = save_onnx(model, shaped_model.name, external_data.name) else: has_external_data = save_onnx(model, shaped_model.name) try: @@ -236,9 +237,13 @@ def override_onnx_batch_size( save_onnx(model, onnx_filepath) yield onnx_filepath else: - return save_onnx_to_temp_files(model, with_external_data=not inplace) + with save_onnx_to_temp_files( + model, with_external_data=not inplace + ) as temp_file: + yield temp_file +@contextlib.contextmanager def override_onnx_input_shapes( onnx_filepath: str, input_shapes: Union[List[int], List[List[int]]], @@ -300,16 +305,19 @@ def override_onnx_input_shapes( if inplace: _LOGGER.info( - "Overwriting in-place the input shapes of the model " f"at {onnx_filepath}" + f"Overwriting in-place the input shapes of the model at {onnx_filepath}" ) onnx.save(model, onnx_filepath) - return onnx_filepath + yield onnx_filepath else: _LOGGER.info( f"Saving the input shapes of the model at {onnx_filepath} " f"to a temporary file" ) - return save_onnx_to_temp_files(model, with_external_data=not inplace) + with save_onnx_to_temp_files( + model, with_external_data=not inplace + ) as temp_file: + yield temp_file def truncate_onnx_model( diff --git a/tests/deepsparse/utils/__init__.py b/tests/deepsparse/utils/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/utils/onnx.py b/tests/deepsparse/utils/onnx.py new file mode 100644 index 0000000000..34b95c9a47 --- /dev/null +++ b/tests/deepsparse/utils/onnx.py @@ -0,0 +1,71 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import onnx + +import pytest +from deepsparse.utils.onnx import override_onnx_batch_size, override_onnx_input_shapes +from sparsezoo import Model + + +@pytest.mark.parametrize( + "test_model, batch_size", + [ + ( + "zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/base-none", # noqa: E501 + 10, + ) + ], + scope="function", +) +@pytest.mark.parametrize("inplace", [True, False], scope="function") +def test_override_onnx_batch_size(test_model, batch_size, inplace): + onnx_file_path = Model(test_model).onnx_model.path + # Override the batch size of the ONNX model + with override_onnx_batch_size( + onnx_file_path, batch_size, inplace=inplace + ) as modified_model_path: + # Load the modified ONNX model + modified_model = onnx.load(modified_model_path) + assert ( + modified_model.graph.input[0].type.tensor_type.shape.dim[0].dim_value + == batch_size + ) + + +@pytest.mark.parametrize( + "test_model, input_shapes", + [ + ( + "zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/base-none", # noqa: E501 + [10, 224, 224, 3], + ) + ], + scope="function", +) +@pytest.mark.parametrize("inplace", [True, False], scope="function") +def test_override_onnx_input_shapes(test_model, input_shapes, inplace): + onnx_file_path = Model(test_model).onnx_model.path + # Override the batch size of the ONNX model + with override_onnx_input_shapes( + onnx_file_path, input_shapes, inplace=inplace + ) as modified_model_path: + # Load the modified ONNX model + modified_model = onnx.load(modified_model_path) + new_input_shapes = [ + dim.dim_value + for dim in modified_model.graph.input[0].type.tensor_type.shape.dim + ] + assert new_input_shapes == input_shapes