Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 57a4dd0

Browse files
authored
[Export refactor] final manual testing fixes (#1948)
* [Export refactor] final manual testing fixes * review
1 parent c3c90a4 commit 57a4dd0

File tree

4 files changed

+39
-15
lines changed

4 files changed

+39
-15
lines changed

src/sparseml/export/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from .export import *

src/sparseml/export/helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
import logging
1616
import os
1717
import shutil
18-
from collections import OrderedDict
1918
from enum import Enum
2019
from pathlib import Path
21-
from typing import Any, Callable, Dict, List, Optional, Union
20+
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Union
2221

2322
from sparseml.exporters import ExportTargets
2423

src/sparseml/integration_helper_functions.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,19 @@ def resolve_integration(
5151
will attempt to infer it from the source_path.
5252
:return: The name of the integration to use for exporting the model.
5353
"""
54-
from sparseml.pytorch.image_classification.utils.helpers import (
55-
is_image_classification_model,
56-
)
57-
from sparseml.transformers.utils.helpers import is_transformer_model
54+
try:
55+
from sparseml.pytorch.image_classification.utils.helpers import (
56+
is_image_classification_model,
57+
)
58+
except ImportError:
59+
# unable to import integration, always return False
60+
is_image_classification_model = _null_is_model
61+
62+
try:
63+
from sparseml.transformers.utils.helpers import is_transformer_model
64+
except ImportError:
65+
# unable to import integration, always return False
66+
is_transformer_model = _null_is_model
5867

5968
if (
6069
integration == Integrations.image_classification.value
@@ -63,7 +72,6 @@ def resolve_integration(
6372
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401
6473

6574
return Integrations.image_classification.value
66-
6775
elif integration == Integrations.transformers.value or is_transformer_model(
6876
source_path
6977
):
@@ -80,6 +88,12 @@ def resolve_integration(
8088
)
8189

8290

91+
def _null_is_model(*args, **kwargs):
92+
# convenience function to always return False for an integration
93+
# to be used if that integration is not importable
94+
return False
95+
96+
8397
class IntegrationHelperFunctions(RegistryMixin, BaseModel):
8498
"""
8599
Registry that maps names to helper functions
@@ -88,7 +102,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
88102
"""
89103

90104
create_model: Callable[
91-
[Union[str, Path], ...],
105+
[Union[str, Path]],
92106
Tuple[
93107
"torch.nn.Module", # noqa F821
94108
Optional[Dict[str, Any]],
@@ -102,13 +116,13 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
102116
"- (optionally) loaded_model_kwargs "
103117
"(any relevant objects created along with the model)"
104118
)
105-
create_dummy_input: Callable[..., "torch.Tensor"] = Field( # noqa F821
119+
create_dummy_input: Callable[[Any], "torch.Tensor"] = Field( # noqa F821
106120
description="A function that takes: "
107121
"- appropriate arguments "
108122
"and returns: "
109123
"- a dummy input for the model (a torch.Tensor) "
110124
)
111-
export: Callable[..., str] = Field(
125+
export: Callable[[Any], str] = Field(
112126
description="A function that takes: "
113127
" - a (sparse) PyTorch model "
114128
" - sample input data "
@@ -120,15 +134,19 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
120134
"and returns the path to the exported model",
121135
default=export_model,
122136
)
123-
apply_optimizations: Optional[Callable[..., None]] = Field(
137+
apply_optimizations: Optional[Callable[[Any], None]] = Field(
124138
description="A function that takes:"
125139
" - path to the exported model"
126140
" - names of the optimizations to apply"
127141
" and applies the optimizations to the model",
128142
)
129143

130144
create_data_samples: Callable[
131-
Tuple[Optional["torch.nn.Module"], int, Optional[Dict[str, Any]]], # noqa F821
145+
[
146+
Tuple[
147+
Optional["torch.nn.Module"], int, Optional[Dict[str, Any]] # noqa: F821
148+
]
149+
],
132150
Tuple[
133151
List["torch.Tensor"], # noqa F821
134152
Optional[List["torch.Tensor"]], # noqa F821

src/sparseml/transformers/sparsification/trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,14 @@ def _add_tensorboard_logger_if_available(self):
694694
self.args, log_dir=self.args.logging_dir
695695
)
696696

697-
self.logger_manager.add_logger(
698-
TensorBoardLogger(writer=tensorboard_callback.tb_writer)
699-
)
697+
try:
698+
self.logger_manager.add_logger(
699+
TensorBoardLogger(writer=tensorboard_callback.tb_writer)
700+
)
701+
except (ImportError, ModuleNotFoundError):
702+
_LOGGER.info(
703+
f"Unable to import tensorboard - running without tensorboard logging"
704+
)
700705

701706
def _get_fake_dataloader(
702707
self,

0 commit comments

Comments
 (0)