18
18
import tarfile
19
19
from enum import Enum
20
20
from pathlib import Path
21
- from typing import List , Optional , Tuple , Union
21
+ from typing import Any , Dict , List , Optional , Tuple , Union
22
22
23
23
import torch
24
24
from tqdm import tqdm
@@ -46,47 +46,11 @@ class InputsNames(Enum):
46
46
filename = "inp"
47
47
48
48
49
- def create_data_samples (
50
- data_loader : torch .utils .data .DataLoader ,
51
- model : Optional [torch .nn .Module ] = None ,
52
- num_samples : int = 1 ,
53
- ) -> Tuple [List [torch .Tensor ], List [torch .Tensor ], List [torch .Tensor ]]:
54
- """
55
- Fetch a batch of samples from the data loader and return the inputs and outputs
56
-
57
- :param data_loader: The data loader to get a batch of inputs/outputs from.
58
- :param model: The model to run the inputs through to get the outputs.
59
- If None, the outputs will be an empty list.
60
- :param num_samples: The number of samples to generate. Defaults to 1
61
- :return: The inputs and outputs as lists of torch tensors
62
- """
63
- inputs , outputs , labels = [], [], []
64
- if model is None :
65
- _LOGGER .warning ("The model is None. The list of outputs will be empty" )
66
- for batch_num , (inputs_ , labels_ ) in tqdm (enumerate (data_loader )):
67
- if batch_num == num_samples :
68
- break
69
- if model :
70
- outputs_ = model (inputs_ )
71
- if isinstance (outputs_ , tuple ):
72
- # outputs_ contains (logits, softmax)
73
- outputs_ = outputs_ [0 ]
74
- outputs .append (outputs_ )
75
- inputs .append (inputs_ )
76
- labels .append (
77
- torch .IntTensor ([labels_ ])
78
- if not isinstance (labels_ , torch .Tensor )
79
- else labels_
80
- )
81
-
82
- return inputs , outputs , labels
83
-
84
-
85
49
def export_data_samples (
86
50
target_path : Union [Path , str ],
87
- input_samples : Optional [List ["torch.Tensor" ]] = None , # noqa F821
88
- output_samples : Optional [List ["torch.Tensor" ]] = None , # noqa F821
89
- label_samples : Optional [List ["torch.Tensor" ]] = None , # noqa F821
51
+ input_samples : Optional [List [Any ]] = None ,
52
+ output_samples : Optional [List [Any ]] = None ,
53
+ label_samples : Optional [List [Any ]] = None ,
90
54
as_tar : bool = False ,
91
55
):
92
56
"""
@@ -116,6 +80,7 @@ def export_data_samples(
116
80
117
81
:param input_samples: The input samples to save.
118
82
:param output_samples: The output samples to save.
83
+ :param label_samples: The label samples to save.
119
84
:param target_path: The path to save the samples to.
120
85
:param as_tar: Whether to save the samples as tar files.
121
86
"""
@@ -124,16 +89,21 @@ def export_data_samples(
124
89
[input_samples , output_samples , label_samples ],
125
90
[InputsNames , OutputsNames , LabelNames ],
126
91
):
127
- if samples is not None :
92
+ if len ( samples ) > 0 :
128
93
_LOGGER .info (f"Exporting { names .basename .value } to { target_path } ..." )
129
- export_data_sample (samples , names , target_path , as_tar )
94
+ break_batch = isinstance (samples [0 ], dict )
95
+ export_data_sample (samples , names , target_path , as_tar , break_batch )
130
96
_LOGGER .info (
131
97
f"Successfully exported { names .basename .value } to { target_path } !"
132
98
)
133
99
134
100
135
101
def export_data_sample (
136
- samples , names : Enum , target_path : Union [Path , str ], as_tar : bool = False
102
+ samples ,
103
+ names : Enum ,
104
+ target_path : Union [Path , str ],
105
+ as_tar : bool = False ,
106
+ break_batch = False ,
137
107
):
138
108
139
109
samples = tensors_to_device (samples , "cpu" )
@@ -142,9 +112,105 @@ def export_data_sample(
142
112
tensors = samples ,
143
113
export_dir = os .path .join (target_path , names .basename .value ),
144
114
name_prefix = names .filename .value ,
115
+ break_batch = break_batch ,
145
116
)
146
117
if as_tar :
147
118
folder_path = os .path .join (target_path , names .basename .value )
148
119
with tarfile .open (folder_path + ".tar.gz" , "w:gz" ) as tar :
149
120
tar .add (folder_path , arcname = os .path .basename (folder_path ))
150
121
shutil .rmtree (folder_path )
122
+
123
+
124
def create_data_samples(
    data_loader: torch.utils.data.DataLoader,
    model: Optional[torch.nn.Module] = None,
    num_samples: int = 1,
) -> Tuple[List[Any], List[Any], List[Any]]:
    """
    Fetch batches of samples from the data loader and return the inputs,
    outputs and labels as lists.

    :param data_loader: The data loader to get batches of inputs/outputs from.
        Each batch may be a dict (mapping input names to tensors) or a
        tuple/list of (inputs, labels).
    :param model: The model to run the inputs through to get the outputs.
        If None, the outputs will be an empty list.
    :param num_samples: The number of samples to generate. Defaults to 1
    :return: The inputs and outputs as lists of torch tensors
    """
    inputs, outputs, labels = [], [], []
    if model is None:
        _LOGGER.warning("The model is None. The list of outputs will be empty")

    for batch_num, data in tqdm(enumerate(data_loader)):
        # stop once the requested number of batches has been collected
        if batch_num == num_samples:
            break
        # dispatch on the batch structure produced by the data loader
        if isinstance(data, dict):
            inputs_, labels_, outputs_ = run_inference_with_dict_data(
                data=data, model=model
            )
        elif isinstance(data, (list, tuple)):
            inputs_, labels_, outputs_ = run_inference_with_tuple_or_list_data(
                data=data, model=model
            )
        else:
            # fixed message: list batches are accepted by the branch above,
            # but the original error claimed only dict and tuple were
            raise ValueError(
                f"Data type {type(data)} is not supported. "
                f"Only dict, list and tuple are supported"
            )

        inputs.append(inputs_)
        if outputs_ is not None:
            outputs.append(outputs_)
        if labels_ is not None:
            # normalize scalar/int labels to tensors for uniform export
            labels.append(
                torch.IntTensor([labels_])
                if not isinstance(labels_, torch.Tensor)
                else labels_
            )

    return inputs, outputs, labels
170
+
171
+
172
def run_inference_with_dict_data(
    data: Dict[str, Any], model: Optional[torch.nn.Module] = None
) -> Tuple[Dict[str, Any], Any, Optional[Dict[str, Any]]]:
    """
    Run inference on a model by inferring the appropriate
    inputs from the dictionary input data.

    :param data: The batch data: a mapping from input names to tensors
    :param model: The model to run inference on (optional)
    :return: The inputs, labels and outputs
    """
    # dict-style batches carry no separate label entry
    labels = None
    if model is None:
        output = None
    else:
        # move inputs onto the model's device before the forward pass
        # NOTE(review): presumes `model` exposes a `.device` attribute
        # (not guaranteed for a plain nn.Module) — confirm with callers
        device_inputs = {
            name: tensor.to(model.device) for name, tensor in data.items()
        }
        raw_output = model(**device_inputs)
        # squeeze, detach, and bring every output tensor back to the CPU
        output = {
            name: torch.squeeze(tensor).detach().to("cpu")
            for name, tensor in raw_output.items()
        }
    # return CPU copies of the original inputs for serialization
    inputs = {name: tensor.to("cpu") for name, tensor in data.items()}
    return inputs, labels, output
197
+
198
+
199
def run_inference_with_tuple_or_list_data(
    data: Tuple[Any, Any], model: Optional[torch.nn.Module] = None
) -> Tuple[torch.Tensor, Any, Optional[torch.Tensor]]:
    """
    Run inference on a model by inferring the appropriate
    inputs from the tuple input data.

    :param data: The data to run inference on; assumed to unpack as an
        (inputs, labels) pair
    :param model: The model to run inference on (optional)
    :return: The inputs, labels and outputs
    """
    # assume the batch unpacks as (inputs, labels)
    inputs, labels = data
    outputs = model(inputs) if model is not None else None
    if isinstance(outputs, tuple):
        # outputs contains (logits, softmax); keep only the logits
        outputs = outputs[0]
    return inputs, labels, outputs
0 commit comments