This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 3da5f23

dbogunowicz and bfineran committed
[Export Refactor] Prepare the module to be more general (before including transformers) (#1908)
* adapt the export script to handle transformers
* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py
* Delete tests/sparseml/export/transformers/__init__.py
* Delete tests/sparseml/export/transformers/test_generative_transformers.py
* Delete tests/sparseml/export/transformers/test_transformers.py
* Update src/sparseml/export/export.py
  Co-authored-by: Benjamin Fineran <[email protected]>
* addressing review comments
* [Export Refactor] Export `transformers` (#1909)
* cleanup
* Delete src/sparseml/transformers/integration_helper_functions_generative.py
* Delete src/sparseml/transformers/utils/optimizations.py
* Delete tests/sparseml/export/transformers/test_generative_transformers.py
* Delete tests/sparseml/transformers/test_integration_helper_functions_generative.py
* addressing PR reviews
* [Export Refactor] Export generative transformers (#1910)
* make tests green, remove using task to resolve the integration type
* fix all the tests after the merge, make integration resolution independent of the task name
* fold generative transformers into transformer helper functions
* complete tests for export_data.py
* Update src/sparseml/export/export.py
* add tests that confirm that kv cache injection has been added
* move applying optimizations into integration helper functions

---------

Co-authored-by: Benjamin Fineran <[email protected]>
1 parent 627ddd6 commit 3da5f23

25 files changed: +1626 additions, -323 deletions
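For orientation, the refactor keeps a single integration-agnostic entrypoint, `sparseml.export.export.export`, and threads the new `task` and `batch_size` arguments through it. A minimal sketch of how a call might look after this commit; the checkpoint path and task value below are hypothetical placeholders, and keyword arguments are used because the full parameter order is not shown in these hunks:

    # Hypothetical invocation of the refactored entrypoint; the paths and
    # task value are illustrative placeholders, not values from this commit.
    from sparseml.export.export import export

    export(
        source_path="./trained_model",  # hypothetical checkpoint directory
        target_path="./exported",       # deployment/ is created under here
        task="text-classification",     # new: forwarded to the integration helpers
        batch_size=1,                   # new: replaces the old Optional[int] argument
        num_export_samples=8,           # also writes sample inputs/outputs/labels
    )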

src/sparseml/export/export.py

Lines changed: 42 additions & 38 deletions
@@ -13,20 +13,21 @@
 # limitations under the License.
 
 import logging
+import os
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
 from sparseml.export.export_data import export_data_samples
 from sparseml.export.helpers import (
     AVAILABLE_DEPLOYMENT_TARGETS,
     ONNX_MODEL_NAME,
-    apply_optimizations,
     create_deployment_folder,
+    create_export_kwargs,
 )
 from sparseml.export.validators import validate_correctness as validate_correctness_
 from sparseml.export.validators import validate_structure as validate_structure_
 from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
-from sparseml.pytorch.utils.helpers import default_device
+from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
 from src.sparseml.integration_helper_functions import (
     IntegrationHelperFunctions,
     resolve_integration,
@@ -44,14 +45,15 @@ def export(
     opset: int = TORCH_DEFAULT_ONNX_OPSET,
     single_graph_file: bool = True,
     num_export_samples: int = 0,
+    batch_size: int = 1,
     deployment_directory_name: str = "deployment",
     device: str = "auto",
     graph_optimizations: Union[str, List[str], None] = "all",
     validate_correctness: bool = False,
     validate_structure: bool = True,
     integration: Optional[str] = None,
     sample_data: Optional[Any] = None,
-    batch_size: Optional[int] = None,
+    task: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -84,6 +86,8 @@ def export(
         file. Defaults to True.
     :param num_export_samples: The number of samples to create for
         the exported model. Defaults to 0.
+    :param batch_size: The batch size to use for exporting the data.
+        Defaults to None.
     :param deployment_directory_name: The name of the deployment
         directory to create for the exported model. Thus, the exported
         model will be saved to `target_path/deployment_directory_name`.
@@ -102,7 +106,7 @@ def export(
     :param sample_data: Optional sample data to use for exporting
         the model. If not provided, a dummy input will be created
         for the model. Defaults to None.
-    :param batch_size: The batch size to use for exporting the data.
+    :param task: Optional task to use for exporting the model.
         Defaults to None.
     """
 
@@ -112,6 +116,7 @@ def export(
 
     # choose the appropriate device
     device = default_device() if device == "auto" else device
+    device = use_single_gpu(device) if "cuda" in device else device
 
     # assert the valid deployment target
     if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
@@ -126,69 +131,55 @@ def export(
     _LOGGER.info(f"Starting export for {integration} model...")
 
     helper_functions: IntegrationHelperFunctions = (
-        IntegrationHelperFunctions.load_from_registry(integration)
+        IntegrationHelperFunctions.load_from_registry(integration, task=task)
     )
 
     _LOGGER.info("Creating model for the export...")
-    model, validation_dataloader = helper_functions.create_model(
-        source_path, batch_size, device, **kwargs
+
+    # loaded_model_kwargs may include any objects
+    # that were created along with the model and are needed
+    # for the export
+    model, loaded_model_kwargs = helper_functions.create_model(
+        source_path, device=device, task=task, batch_size=batch_size, **kwargs
     )
 
-    if validation_dataloader:
-        _LOGGER.info("Created validation dataloader for the export")
-    else:
-        _LOGGER.warning(
-            "Failed to create validation dataloader for the export. "
-            "Will be using the dummy (or user-provided) data instead "
-            "and will be not able to export samples or validate the model "
-            "correctness."
+    if loaded_model_kwargs:
+        _LOGGER.info(
+            "Created additional items that will "
+            f"be used for the export: {list(loaded_model_kwargs.keys())}"
         )
 
     sample_data = (
-        helper_functions.create_dummy_input(
-            validation_dataloader=validation_dataloader, **kwargs
-        )
+        helper_functions.create_dummy_input(**loaded_model_kwargs, **kwargs)
         if sample_data is None
         else sample_data
     )
 
     _LOGGER.info(f"Exporting {onnx_model_name} to {target_path}...")
+
+    export_kwargs = create_export_kwargs(loaded_model_kwargs)
+
     onnx_file_path = helper_functions.export(
         model=model,
         sample_data=sample_data,
         target_path=target_path,
         onnx_model_name=onnx_model_name,
         deployment_target=deployment_target,
         opset=opset,
+        **export_kwargs,
     )
-    _LOGGER.info(f"Successfully exported {onnx_model_name} to {target_path}...")
-
-    _LOGGER.info(
-        f"Applying optimizations: {graph_optimizations} to the exported model..."
-    )
-    apply_optimizations(
-        onnx_file_path=onnx_file_path,
-        target_optimizations=graph_optimizations,
-        available_optimizations=helper_functions.graph_optimizations,
-        single_graph_file=single_graph_file,
-    )
+    _LOGGER.info(f"Successfully exported {onnx_model_name} to {onnx_file_path}...")
 
     if num_export_samples:
         _LOGGER.info(f"Exporting {num_export_samples} samples...")
-        if not validation_dataloader:
-            raise ValueError(
-                "To export sample inputs/outputs a data loader is needed. "
-                "To return a data loader provide the appropriate, integration-specific "
-                "arguments to `create_model` function"
-            )
         (
             input_samples,
            output_samples,
            label_samples,
        ) = helper_functions.create_data_samples(
            num_samples=num_export_samples,
-            data_loader=validation_dataloader,
            model=model,
+            **loaded_model_kwargs,
        )
        export_data_samples(
            input_samples=input_samples,
@@ -207,16 +198,29 @@ def export(
        source_path=source_path,
        target_path=target_path,
        deployment_directory_name=deployment_directory_name,
-        deployment_directory_files=helper_functions.deployment_directory_structure,
+        deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory,  # noqa: E501
+        deployment_directory_files_optional=helper_functions.deployment_directory_files_optional,  # noqa: E501
        onnx_model_name=onnx_model_name,
    )
+
+    _LOGGER.info(
+        f"Applying optimizations: {graph_optimizations} to the exported model..."
+    )
+    if helper_functions.apply_optimizations is not None:
+        helper_functions.apply_optimizations(
+            exported_file_path=os.path.join(deployment_path, onnx_model_name),
+            optimizations=graph_optimizations,
+            single_graph_file=single_graph_file,
+        )
+
     if validate_structure:
         _LOGGER.info("Validating model structure...")
         validate_structure_(
             target_path=target_path,
             deployment_directory_name=deployment_directory_name,
             onnx_model_name=onnx_model_name,
-            deployment_directory_files=helper_functions.deployment_directory_structure,
+            deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory,  # noqa: E501
+            deployment_directory_files_optional=helper_functions.deployment_directory_files_optional,  # noqa: E501
         )
 
     if validate_correctness:
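The key contract change above is that an integration's `create_model` helper now returns `(model, loaded_model_kwargs)` instead of `(model, validation_dataloader)`; whatever lands in that dict is fanned out to `create_dummy_input`, `create_export_kwargs`, and `create_data_samples`. A minimal sketch of a conforming helper, with a toy model and loader standing in for real checkpoint loading (the body is illustrative, not code from this commit):

    from typing import Any, Dict, Optional, Tuple

    import torch


    def create_model(
        source_path: str,
        device: Optional[str] = None,
        task: Optional[str] = None,
        batch_size: int = 1,
        **kwargs,
    ) -> Tuple[torch.nn.Module, Dict[str, Any]]:
        # stand-in for loading a real checkpoint from source_path
        model = torch.nn.Linear(8, 2)
        data_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.randn(4, 8), torch.arange(4)),
            batch_size=batch_size,
        )
        # every key in this dict is forwarded by export()
        # via **loaded_model_kwargs to the downstream helpers
        return model, {"data_loader": data_loader}

Because `export()` only logs `list(loaded_model_kwargs.keys())` and splats the dict into the downstream helpers, an integration can pass along tokenizers, data loaders, or cache configuration without widening the `export()` signature.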

src/sparseml/export/export_data.py

Lines changed: 109 additions & 43 deletions
@@ -18,7 +18,7 @@
 import tarfile
 from enum import Enum
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from tqdm import tqdm
@@ -46,47 +46,11 @@ class InputsNames(Enum):
     filename = "inp"
 
 
-def create_data_samples(
-    data_loader: torch.utils.data.DataLoader,
-    model: Optional[torch.nn.Module] = None,
-    num_samples: int = 1,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
-    """
-    Fetch a batch of samples from the data loader and return the inputs and outputs
-
-    :param data_loader: The data loader to get a batch of inputs/outputs from.
-    :param model: The model to run the inputs through to get the outputs.
-        If None, the outputs will be an empty list.
-    :param num_samples: The number of samples to generate. Defaults to 1
-    :return: The inputs and outputs as lists of torch tensors
-    """
-    inputs, outputs, labels = [], [], []
-    if model is None:
-        _LOGGER.warning("The model is None. The list of outputs will be empty")
-    for batch_num, (inputs_, labels_) in tqdm(enumerate(data_loader)):
-        if batch_num == num_samples:
-            break
-        if model:
-            outputs_ = model(inputs_)
-            if isinstance(outputs_, tuple):
-                # outputs_ contains (logits, softmax)
-                outputs_ = outputs_[0]
-            outputs.append(outputs_)
-        inputs.append(inputs_)
-        labels.append(
-            torch.IntTensor([labels_])
-            if not isinstance(labels_, torch.Tensor)
-            else labels_
-        )
-
-    return inputs, outputs, labels
-
-
 def export_data_samples(
     target_path: Union[Path, str],
-    input_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
-    output_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
-    label_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
+    input_samples: Optional[List[Any]] = None,
+    output_samples: Optional[List[Any]] = None,
+    label_samples: Optional[List[Any]] = None,
     as_tar: bool = False,
 ):
     """
@@ -116,6 +80,7 @@ def export_data_samples(
 
     :param input_samples: The input samples to save.
     :param output_samples: The output samples to save.
+    :param label_samples: The label samples to save.
     :param target_path: The path to save the samples to.
     :param as_tar: Whether to save the samples as tar files.
     """
@@ -124,16 +89,21 @@ def export_data_samples(
         [input_samples, output_samples, label_samples],
         [InputsNames, OutputsNames, LabelNames],
     ):
-        if samples is not None:
+        if len(samples) > 0:
             _LOGGER.info(f"Exporting {names.basename.value} to {target_path}...")
-            export_data_sample(samples, names, target_path, as_tar)
+            break_batch = isinstance(samples[0], dict)
+            export_data_sample(samples, names, target_path, as_tar, break_batch)
             _LOGGER.info(
                 f"Successfully exported {names.basename.value} to {target_path}!"
             )
 
 
 def export_data_sample(
-    samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False
+    samples,
+    names: Enum,
+    target_path: Union[Path, str],
+    as_tar: bool = False,
+    break_batch=False,
 ):
 
     samples = tensors_to_device(samples, "cpu")
@@ -142,9 +112,105 @@ def export_data_sample(
         tensors=samples,
         export_dir=os.path.join(target_path, names.basename.value),
         name_prefix=names.filename.value,
+        break_batch=break_batch,
     )
     if as_tar:
         folder_path = os.path.join(target_path, names.basename.value)
         with tarfile.open(folder_path + ".tar.gz", "w:gz") as tar:
             tar.add(folder_path, arcname=os.path.basename(folder_path))
         shutil.rmtree(folder_path)
+
+
+def create_data_samples(
+    data_loader: torch.utils.data.DataLoader,
+    model: Optional[torch.nn.Module] = None,
+    num_samples: int = 1,
+) -> Tuple[List[Any], List[Any], List[Any]]:
+    """
+    Fetch a batch of samples from the data loader and return the inputs and outputs
+
+    :param data_loader: The data loader to get a batch of inputs/outputs from.
+    :param model: The model to run the inputs through to get the outputs.
+        If None, the outputs will be an empty list.
+    :param num_samples: The number of samples to generate. Defaults to 1
+    :return: The inputs and outputs as lists of torch tensors
+    """
+    inputs, outputs, labels = [], [], []
+    if model is None:
+        _LOGGER.warning("The model is None. The list of outputs will be empty")
+
+    for batch_num, data in tqdm(enumerate(data_loader)):
+        if batch_num == num_samples:
+            break
+        if isinstance(data, dict):
+            inputs_, labels_, outputs_ = run_inference_with_dict_data(
+                data=data, model=model
+            )
+        elif isinstance(data, (list, tuple)):
+            inputs_, labels_, outputs_ = run_inference_with_tuple_or_list_data(
+                data=data, model=model
+            )
+        else:
+            raise ValueError(
+                f"Data type {type(data)} is not supported. "
+                f"Only dict and tuple are supported"
+            )
+
+        inputs.append(inputs_)
+        if outputs_ is not None:
+            outputs.append(outputs_)
+        if labels_ is not None:
+            labels.append(
+                torch.IntTensor([labels_])
+                if not isinstance(labels_, torch.Tensor)
+                else labels_
+            )
+
+    return inputs, outputs, labels
+
+
+def run_inference_with_dict_data(
+    data: Dict[str, Any], model: Optional[torch.nn.Module] = None
+) -> Tuple[Dict[str, Any], Any, Optional[Dict[str, Any]]]:
+    """
+    Run inference on a model by inferring the appropriate
+    inputs from the dictionary input data.
+
+    :param data: The data to run inference on
+    :param model: The model to run inference on (optional)
+    :return: The inputs, labels and outputs
+    """
+    labels = None
+    if model is None:
+        output = None
+
+    else:
+        inputs = {key: value.to(model.device) for key, value in data.items()}
+        output_vals = model(**inputs)
+        output = {
+            name: torch.squeeze(val).detach().to("cpu")
+            for name, val in output_vals.items()
+        }
+    inputs = {key: value.to("cpu") for key, value in data.items()}
+    return inputs, labels, output
+
+
+def run_inference_with_tuple_or_list_data(
+    data: Tuple[Any, Any], model: Optional[torch.nn.Module] = None
+) -> Tuple[torch.Tensor, Any, Optional[torch.Tensor]]:
+    """
+    Run inference on a model by inferring the appropriate
+    inputs from the tuple input data.
+
+    :param inputs: The data to run inference on
+    :param model: The model to run inference on (optional)
+    :return: The inputs, labels and outputs
+    """
+    # assume that
+    inputs, labels = data
+    outputs = model(inputs) if model else None
+    if isinstance(outputs, tuple):
+        # outputs_ contains (logits, softmax)
+        outputs = outputs[0]
+    return inputs, labels, outputs
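The reworked `create_data_samples` now dispatches on the batch type: dict batches (as transformers collators yield) go through `run_inference_with_dict_data`, while tuple/list batches (the image-classification path) go through `run_inference_with_tuple_or_list_data`. A small sketch exercising both paths with toy loaders, assuming the module layout shown above; with `model=None` the outputs list stays empty by design:

    import torch

    from sparseml.export.export_data import create_data_samples

    # tuple-style batches, as an image-classification DataLoader would yield
    tuple_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.randn(4, 3), torch.arange(4)),
        batch_size=2,
    )
    inputs, outputs, labels = create_data_samples(tuple_loader, model=None, num_samples=2)
    assert len(inputs) == 2 and outputs == []  # no model -> no outputs collected

    # dict-style batches, as a transformers tokenizer/collator would yield;
    # any iterable of dicts suffices for the dispatch logic sketched here
    dict_loader = [{"input_ids": torch.ones(1, 4, dtype=torch.long)}]
    inputs, outputs, labels = create_data_samples(dict_loader, model=None, num_samples=1)
    assert isinstance(inputs[0], dict)  # dict batches are kept as dicts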
