@@ -51,10 +51,19 @@ def resolve_integration(
51
51
will attempt to infer it from the source_path.
52
52
:return: The name of the integration to use for exporting the model.
53
53
"""
54
- from sparseml .pytorch .image_classification .utils .helpers import (
55
- is_image_classification_model ,
56
- )
57
- from sparseml .transformers .utils .helpers import is_transformer_model
54
+ try :
55
+ from sparseml .pytorch .image_classification .utils .helpers import (
56
+ is_image_classification_model ,
57
+ )
58
+ except ImportError :
59
+ # unable to import integration, always return False
60
+ is_image_classification_model = _null_is_model
61
+
62
+ try :
63
+ from sparseml .transformers .utils .helpers import is_transformer_model
64
+ except ImportError :
65
+ # unable to import integration, always return False
66
+ is_transformer_model = _null_is_model
58
67
59
68
if (
60
69
integration == Integrations .image_classification .value
@@ -63,7 +72,6 @@ def resolve_integration(
63
72
import sparseml .pytorch .image_classification .integration_helper_functions # noqa F401
64
73
65
74
return Integrations .image_classification .value
66
-
67
75
elif integration == Integrations .transformers .value or is_transformer_model (
68
76
source_path
69
77
):
@@ -80,6 +88,12 @@ def resolve_integration(
80
88
)
81
89
82
90
91
+ def _null_is_model (* args , ** kwargs ):
92
+ # convenience function to always return False for an integration
93
+ # to be used if that integration is not importable
94
+ return False
95
+
96
+
83
97
class IntegrationHelperFunctions (RegistryMixin , BaseModel ):
84
98
"""
85
99
Registry that maps names to helper functions
@@ -88,7 +102,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
88
102
"""
89
103
90
104
create_model : Callable [
91
- [Union [str , Path ], ... ],
105
+ [Union [str , Path ]],
92
106
Tuple [
93
107
"torch.nn.Module" , # noqa F821
94
108
Optional [Dict [str , Any ]],
@@ -102,13 +116,13 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
102
116
"- (optionally) loaded_model_kwargs "
103
117
"(any relevant objects created along with the model)"
104
118
)
105
- create_dummy_input : Callable [... , "torch.Tensor" ] = Field ( # noqa F821
119
+ create_dummy_input : Callable [[ Any ] , "torch.Tensor" ] = Field ( # noqa F821
106
120
description = "A function that takes: "
107
121
"- appropriate arguments "
108
122
"and returns: "
109
123
"- a dummy input for the model (a torch.Tensor) "
110
124
)
111
- export : Callable [... , str ] = Field (
125
+ export : Callable [[ Any ] , str ] = Field (
112
126
description = "A function that takes: "
113
127
" - a (sparse) PyTorch model "
114
128
" - sample input data "
@@ -120,15 +134,19 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
120
134
"and returns the path to the exported model" ,
121
135
default = export_model ,
122
136
)
123
- apply_optimizations : Optional [Callable [... , None ]] = Field (
137
+ apply_optimizations : Optional [Callable [[ Any ] , None ]] = Field (
124
138
description = "A function that takes:"
125
139
" - path to the exported model"
126
140
" - names of the optimizations to apply"
127
141
" and applies the optimizations to the model" ,
128
142
)
129
143
130
144
create_data_samples : Callable [
131
- Tuple [Optional ["torch.nn.Module" ], int , Optional [Dict [str , Any ]]], # noqa F821
145
+ [
146
+ Tuple [
147
+ Optional ["torch.nn.Module" ], int , Optional [Dict [str , Any ]] # noqa: F821
148
+ ]
149
+ ],
132
150
Tuple [
133
151
List ["torch.Tensor" ], # noqa F821
134
152
Optional [List ["torch.Tensor" ]], # noqa F821
0 commit comments