# limitations under the License.

import logging
+ import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import (
    TextGenerationPipelineNoCache,
)
+ from deepsparse.transformers.utils.eval_helpers import (
+     HumanEvalIteratorWrapper,
+     process_concatenated_datasets,
+ )
"""
@EvaluationRegistry.register(name=PERPLEXITY)
def integration_eval(
    pipeline: Pipeline,
-     datasets: Union[List[str], str],
+     datasets: Union[List[str], str] = "openai_humaneval",
    batch_size: int = 1,
    limit: Optional[int] = None,
+     accumulate: Optional[bool] = None,
    splits: Union[List[str], str, None] = "test",
    metrics: Union[List[str], str, None] = None,
    **kwargs,
@@ -61,6 +67,11 @@ def integration_eval(
    :param metrics: the metrics to compute. Default is None
    :param limit: the number of batches to evaluate on. Default is None
        (evaluates on entire dataset)
+     :param accumulate: whether the perplexity computation should
+         accumulate negative log-likelihood over samples. Defaults to
+         the accumulate value inferred from the dataset in `datasets`;
+         if not None, it overrides the inferred value.
    :return: a Result object containing the raw and formatted results
    """
    metrics = metrics or PERPLEXITY
@@ -69,12 +80,21 @@ def integration_eval(
    if splits is None:
        splits = "test"
        _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.")
-
    datasets = datasets if isinstance(datasets, list) else [datasets]
    results_raw = defaultdict(str)
    for dataset_name in datasets:
        results_raw[dataset_name] = defaultdict()
-         dataset, accumulate = load_perplexity_dataset(dataset_name, splits)
+         dataset, _accumulate = load_perplexity_dataset(
+             dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs
+         )
+         if accumulate is None:
+             accumulate = _accumulate
+         else:
+             _LOGGER.info(
+                 f"Argument `accumulate` set to {accumulate}. "
+                 "Overriding the inferred accumulate variable from the dataset."
+             )
+
        perplexity = run_perplexity(
            pipeline=pipeline,
            dataset=dataset,
@@ -100,7 +120,7 @@ def integration_eval(

def run_perplexity(
    pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline],
-     dataset: Any,  # TODO: Edit, once we agree on the dataset registry
+     dataset: "Dataset",
    batch_size: int,
    accumulate: bool,
    limit: Optional[int] = None,
@@ -127,8 +147,6 @@ def run_perplexity(
    for idx, sample in _enumerate_progress(
        dataset, max_steps=None if limit is None else limit * batch_size
    ):
-         # TODO: To remove when we have support for more datasets
-         sample = sample["prompt"] + sample["canonical_solution"]

        if limit is not None:
            # stop if we have reached the #limit
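The prompt/canonical-solution concatenation removed in this hunk now happens inside the imported HumanEvalIteratorWrapper. Its implementation is not shown in this diff; purely as an illustration inferred from the removed lines, the wrapper is expected to behave along these lines (sketch only, not the actual deepsparse code):

# Sketch (assumption): an iterator wrapper that joins each HumanEval sample's
# prompt with its canonical solution, i.e. the inline logic removed above.
# The real HumanEvalIteratorWrapper in deepsparse.transformers.utils.eval_helpers
# may differ in details.
class HumanEvalIteratorSketch:
    def __init__(self, dataset):
        self._iterator = iter(dataset)

    def __iter__(self):
        return self

    def __next__(self):
        # each sample is a dict with "prompt" and "canonical_solution" keys
        sample = next(self._iterator)
        return sample["prompt"] + sample["canonical_solution"]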
@@ -203,21 +221,57 @@ def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
    return formatted_results


- def load_perplexity_dataset(dataset_name: str, splits: Union[List[str], str] = "test"):
+ def load_perplexity_dataset(
+     dataset_name: str,
+     splits: Union[List[str], str] = "test",
+     pipeline: Optional[Pipeline] = None,
+     **kwargs,
+ ):
    """
-     Dummy function to load the dataset for perplexity computation.
+     Function to load the dataset for perplexity computation.
    Eventually we want to load the dataset from the nm_utils
+
+     :param dataset_name: the name of the dataset to load
+     :param splits: the splits to load from the dataset. Default is "test"
+     :param pipeline: the pipeline to use for loading the dataset. The pipeline
+         is used to infer the model path and sequence length to use when
+         loading the dataset. This argument can be omitted if the appropriate
+         kwargs are provided, or if the dataset does not need to be loaded
+         through process_concatenated_datasets
+     :param kwargs: additional keyword arguments to pass to the dataset loading function
+     :return: the dataset and whether to accumulate perplexity over samples
    """
    if isinstance(splits, list):
        raise NotImplementedError("Evaluation on multiple splits not implemented")

    if dataset_name == "openai_humaneval":
        dataset = load_dataset(dataset_name, split=splits)
+         dataset = HumanEvalIteratorWrapper(dataset)
        accumulate = False
-         return dataset, accumulate
+     elif dataset_name in {"wikitext2", "c4"}:
+         # fetch max_sequence_length from pipeline if not provided
+         max_sequence_length = kwargs.pop("max_sequence_length", None)
+         if max_sequence_length is None and pipeline is not None:
+             max_sequence_length = pipeline.sequence_length
+
+         # fetch model_path from pipeline if not provided
+         model_path = kwargs.pop("model_path", None)
+         if model_path is None and pipeline is not None:
+             model_path = os.path.dirname(pipeline.model_path)
+
+         dataset = process_concatenated_datasets(
+             dataset_name,
+             model_path=model_path,
+             max_sequence_length=max_sequence_length,
+             split=splits,
+             **kwargs,
+         )
+         accumulate = True
    else:
        raise NotImplementedError(f"Dataset {dataset_name} not implemented")

+     return dataset, accumulate
+

def _enumerate_progress(dataset, max_steps):
    progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
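For context, a minimal usage sketch of the updated entrypoint, assuming an already-constructed text-generation pipeline (the `pipeline` object and the `limit` values are placeholders, not part of this diff):

# Minimal usage sketch. With the default dataset (openai_humaneval), accumulate
# is inferred as False; for wikitext2/c4 it is inferred as True unless overridden.
result = integration_eval(pipeline, limit=2)
result_wikitext = integration_eval(
    pipeline,
    datasets="wikitext2",
    limit=2,
    accumulate=True,  # explicit override; otherwise inferred from the dataset
)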
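Similarly, a sketch of calling load_perplexity_dataset directly without a pipeline, in which case the wikitext2/c4 branch expects the model path and sequence length via kwargs (the path and length below are hypothetical placeholders):

# Sketch: loading wikitext2 without a pipeline. model_path and
# max_sequence_length are placeholder values for illustration only.
dataset, accumulate = load_perplexity_dataset(
    dataset_name="wikitext2",
    splits="test",
    model_path="/path/to/model/deployment",
    max_sequence_length=2048,
)
assert accumulate is True  # wikitext2/c4 accumulate NLL over the whole dataset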