
Commit 898f677

[DeepSparse Evaluation API] Perplexity eval support for openai_humaneval, c4, wikitext2 (#1586)
* fix tests 2
* initial commit
* add return to a function
* make script more robust
1 parent c17d8c6 commit 898f677

4 files changed: +122 -37 lines changed


src/deepsparse/evaluation/integrations/perplexity.py

+63 -9

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import os
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 
@@ -29,6 +30,10 @@
 from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import (
     TextGenerationPipelineNoCache,
 )
+from deepsparse.transformers.utils.eval_helpers import (
+    HumanEvalIteratorWrapper,
+    process_concatenated_datasets,
+)
 
 
 """
@@ -41,9 +46,10 @@
 @EvaluationRegistry.register(name=PERPLEXITY)
 def integration_eval(
     pipeline: Pipeline,
-    datasets: Union[List[str], str],
+    datasets: Union[List[str], str] = "openai_humaneval",
     batch_size: int = 1,
     limit: Optional[int] = None,
+    accumulate: Optional[bool] = None,
     splits: Union[List[str], str, None] = "test",
     metrics: Union[List[str], str, None] = None,
     **kwargs,
@@ -61,6 +67,11 @@ def integration_eval(
     :param metrics: the metrics to compute. Default is None
     :param limit: the number of batches to evaluate on. Default is None
         (evaluates on entire dataset)
+    :param accumulate: whether the perplexity computation should
+        accumulate negative log-likelihood over samples. Defaults to
+        the default accumulate variable inferred from the dataset in
+        `datasets`. If not None, it will override the inferred accumulate
+        variable.
     :return: a Result object containing the raw and formatted results
     """
     metrics = metrics or PERPLEXITY
@@ -69,12 +80,21 @@ def integration_eval(
     if splits is None:
         splits = "test"
         _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.")
-
     datasets = datasets if isinstance(datasets, list) else [datasets]
     results_raw = defaultdict(str)
     for dataset_name in datasets:
         results_raw[dataset_name] = defaultdict()
-        dataset, accumulate = load_perplexity_dataset(dataset_name, splits)
+        dataset, _accumulate = load_perplexity_dataset(
+            dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs
+        )
+        if accumulate is None:
+            accumulate = _accumulate
+        else:
+            _LOGGER.info(
+                f"Argument `accumulate` set to {accumulate}. "
+                "Overriding the inferred accumulate variable from the dataset."
+            )
+
         perplexity = run_perplexity(
             pipeline=pipeline,
             dataset=dataset,
@@ -100,7 +120,7 @@ def integration_eval(
 
 def run_perplexity(
     pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline],
-    dataset: Any,  # TODO: Edit, once we agree on the dataset registry
+    dataset: "Dataset",
     batch_size: int,
     accumulate: bool,
     limit: Optional[int] = None,
@@ -127,8 +147,6 @@ def run_perplexity(
     for idx, sample in _enumerate_progress(
         dataset, max_steps=None if limit is None else limit * batch_size
     ):
-        # TODO: To remove when we have support for more datasets
-        sample = sample["prompt"] + sample["canonical_solution"]
 
         if limit is not None:
             # stop if we have reached the #limit
@@ -203,21 +221,57 @@ def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
     return formatted_results
 
 
-def load_perplexity_dataset(dataset_name: str, splits: Union[List[str], str] = "test"):
+def load_perplexity_dataset(
+    dataset_name: str,
+    splits: Union[List[str], str] = "test",
+    pipeline: Optional[Pipeline] = None,
+    **kwargs,
+):
     """
-    Dummy function to load the dataset for perplexity computation.
+    Function to load the dataset for perplexity computation.
     Eventually we want to load the dataset from the nm_utils
+
+    :param dataset_name: the name of the dataset to load
+    :param splits: the splits to load from the dataset. Default is "test"
+    :param pipeline: the pipeline to use for loading the dataset. The pipeline
+        is used to infer the model path and sequence length to use for loading
+        the dataset. This argument can be omitted if the appropriate kwargs
+        are provided, or if the dataset does not require a process_concatenated_datasets
+        function to load the dataset.
+    :param kwargs: additional keyword arguments to pass to the dataset loading function
+    :return: the dataset and whether to accumulate perplexity over samples
    """
     if isinstance(splits, list):
         raise NotImplementedError("Evaluation on multiple splits not implemented")
 
     if dataset_name == "openai_humaneval":
         dataset = load_dataset(dataset_name, split=splits)
+        dataset = HumanEvalIteratorWrapper(dataset)
         accumulate = False
-        return dataset, accumulate
+    elif dataset_name in {"wikitext2", "c4"}:
+        # fetch max_sequence_length from pipeline if not provided
+        max_sequence_length = kwargs.pop("max_sequence_length", None)
+        if max_sequence_length is None and pipeline is not None:
+            max_sequence_length = pipeline.sequence_length
+
+        # fetch model_path from pipeline if not provided
+        model_path = kwargs.pop("model_path", None)
+        if model_path is None and pipeline is not None:
+            model_path = os.path.dirname(pipeline.model_path)
+
+        dataset = process_concatenated_datasets(
+            dataset_name,
+            model_path=model_path,
+            max_sequence_length=max_sequence_length,
+            split=splits,
+            **kwargs,
+        )
+        accumulate = True
     else:
         raise NotImplementedError(f"Dataset {dataset_name} not implemented")
 
+    return dataset, accumulate
+
 
 def _enumerate_progress(dataset, max_steps):
     progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
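For orientation, below is a minimal usage sketch of the updated entry point. The import paths are inferred from the file locations in this diff and the model stub is the one used in the tests further down; treat both as assumptions rather than documented API.

# Sketch only: imports inferred from this commit's file layout.
from deepsparse.evaluation.integrations.perplexity import integration_eval
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline

pipeline = TextGenerationPipeline(
    model_path="hf:mgoin/TinyStories-1M-deepsparse",  # stub borrowed from the tests
    engine_type="onnxruntime",
)

# wikitext2/c4 infer accumulate=True, openai_humaneval infers accumulate=False;
# passing accumulate explicitly overrides the value inferred by
# load_perplexity_dataset. Extra kwargs (e.g. max_sequence_length) are forwarded
# to the dataset loader, as in the tests below.
result = integration_eval(
    pipeline=pipeline,
    datasets="wikitext2",
    batch_size=1,
    limit=2,                 # evaluate only 2 batches
    max_sequence_length=16,  # forwarded via **kwargs to the dataset loader
)
print(result.formatted[0].metrics[0].value)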

src/deepsparse/transformers/metrics.py

+0 -1

@@ -275,6 +275,5 @@ def _cross_entropy(
         neg_log_likelihoods = neg_log_likelihoods.mean(axis=-1)
     elif reduction == "sum":
         neg_log_likelihoods = neg_log_likelihoods.sum(axis=-1)
-    print(neg_log_likelihoods)
 
     return neg_log_likelihoods

src/deepsparse/transformers/utils/eval_helpers.py

+28 -6

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Mapping, Union
+from typing import List, Union
 
 import numpy
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
@@ -27,7 +27,8 @@ def process_concatenated_datasets(
     dataset_name: str,
     model_path: str,
     max_sequence_length: int,
-    kwargs: Mapping,
+    split: str = "test",
+    **kwargs,
 ) -> list:
     """
     Concatenate text datasets and split them into chunks of text that, after
@@ -38,6 +39,8 @@ def process_concatenated_datasets(
             Options: "wikitext2" or "c4".
         model_path (str): The path to a pretrained transformer model for tokenization.
         max_sequence_length (int): The maximum number of tokens in each sequence.
+        split (str, optional): The split of the dataset to use.
+            Default is "test".
         kwargs (mapping): Additional keyword arguments.
             - eos (str, optional): The end-of-sentence token.
                 Default is "\n\n" for wikitext2 and "" for c4.
@@ -65,27 +68,27 @@ def process_concatenated_datasets(
         eos = kwargs.get("eos", "\n\n")
         bos = kwargs.get("bos", "")
 
-        raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+        raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
         raw_text = raw_dataset["text"]
     elif dataset_name == "c4":
         eos = kwargs.get("eos", "<|endoftext|>")
         bos = kwargs.get("bos", "")
         raw_samples = kwargs.get("raw_samples", None)
-        data_file = kwargs.get("data_file", 0)
+        data_file = kwargs.get("data_file", None)
         if data_file is not None:
             raw_dataset = load_dataset(
                 "allenai/c4",
                 "allenai--c4",
                 data_files={
                     "validation": f"en/c4-validation.{data_file:05d}-of-00008.json.gz"
                 },
-                split="validation",
+                split=split,
             )
         else:
             raw_dataset = load_dataset(
                 "allenai/c4",
                 "allenai--c4",
-                split="validation",
+                split=split,
             )
         if raw_samples is not None:
             raw_dataset = raw_dataset[:raw_samples]
@@ -181,3 +184,22 @@ def _split_text_by_tokens(
     )
 
     return split_text
+
+
+class HumanEvalIteratorWrapper:
+    """
+    Wrapper around the `openai_humaneval` dataset,
+    that joins the prompt and the canonical solution
+    into a single string during iteration.
+    """
+
+    def __init__(self, dataset):
+        self.iterator = iter(dataset)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        # Get the next sample from the original iterator
+        sample = next(self.iterator)
+        return sample["prompt"] + sample["canonical_solution"]
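A small sketch of how the two helpers added here fit together, mirroring what load_perplexity_dataset does above. It assumes the Hugging Face datasets package is installed and that a local directory with a compatible tokenizer is available for the wikitext2 path; both are assumptions, not part of this commit.

from datasets import load_dataset
from deepsparse.transformers.utils.eval_helpers import (
    HumanEvalIteratorWrapper,
    process_concatenated_datasets,
)

# openai_humaneval: each iteration yields prompt + canonical_solution as one string
humaneval = HumanEvalIteratorWrapper(load_dataset("openai_humaneval", split="test"))
first_sample = next(iter(humaneval))

# wikitext2: concatenate the raw split and re-chunk it so that each chunk
# tokenizes to at most max_sequence_length tokens
chunks = process_concatenated_datasets(
    "wikitext2",
    model_path="/path/to/model/deployment",  # hypothetical path to a tokenizer dir
    max_sequence_length=16,
    split="test",
)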

tests/deepsparse/evaluation/integrations/test_perplexity.py

+31 -21

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import copy
+
 import numpy as np
 
 import pytest
@@ -37,26 +39,28 @@ def model_id():
     "datasets",
     [
         "openai_humaneval",
-        # TODO: add more datasets
-        # "c4",
-        # "wikitext2",
+        "wikitext2",
     ],
 )
-@pytest.mark.parametrize(
-    "batch_size",
-    [1, 3],
-)
-class TestLMEvaluationHarness:
-    limit = 4
+@pytest.mark.parametrize("batch_size", [1, 2])
+class TestPerplexity:
+    limit = 2
 
     def test_perplexity_ground_truth_equal_pipeline(
         self, model_path, model_id, datasets, batch_size
     ):
+        # setting max_sequence_length to 16 to speed up the test
+        kwargs_ground_truth = (
+            dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {}
+        )
+        kwargs = copy(kwargs_ground_truth)
+
         result_gt = self._get_ground_truth(
             datasets=datasets,
             batch_size=batch_size,
             limit=self.limit,
             model_id=model_id,
+            kwargs=kwargs_ground_truth,
         )
 
         result = integration_eval(
@@ -67,26 +71,34 @@ def test_perplexity_ground_truth_equal_pipeline(
             datasets=datasets,
             batch_size=batch_size,
             limit=self.limit,
+            # we are setting accumulate=False to compare
+            # with the torch ground truth apples to apples
+            accumulate=False,
+            **kwargs,
         )
         perplexities = result.formatted[0].metrics[0].value
         perplexities_gt = result_gt["perplexities"]
-        # TODO: This seemingly big error is due to the fact that
-        # small (1e-2) differences in neg log likelihood get
-        # amplified when computing perplexity
-        # (when applying exp function)
         assert np.allclose(perplexities, perplexities_gt, rtol=0.1)
 
     def test_perplexity_kv_cache_pipeline_equal_no_kv_cache_pipeline(
-        self, model_path, datasets, batch_size
+        self, model_path, model_id, datasets, batch_size
     ):
+
+        kwargs_ground_truth = (
+            dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {}
+        )
+        kwargs = copy(kwargs_ground_truth)
+
         result_kv_cache = integration_eval(
             pipeline=TextGenerationPipeline(
                 model_path="hf:mgoin/TinyStories-1M-deepsparse",
                 engine_type="onnxruntime",
             ),
             datasets=datasets,
+            model_path=model_id,
             batch_size=batch_size,
             limit=self.limit,
+            **kwargs,
         )
 
         result_non_kv_cache = integration_eval(
@@ -98,25 +110,23 @@ def test_perplexity_kv_cache_pipeline_equal_no_kv_cache_pipeline(
             datasets=datasets,
             batch_size=batch_size,
             limit=self.limit,
+            **kwargs,
         )
 
         perplexities_kv_cache = result_kv_cache.formatted[0].metrics[0].value
         perplexities_non_kv_cache = result_non_kv_cache.formatted[0].metrics[0].value
-        # TODO: This seemingly big error is due to the fact that
-        # small (1e-2) differences in neg log likelihood get
-        # amplified when computing perplexity
-        # (when applying exp function).
         np.allclose(perplexities_kv_cache, perplexities_non_kv_cache, rtol=0.1)
 
     @staticmethod
-    def _get_ground_truth(datasets, batch_size, limit, model_id):
+    def _get_ground_truth(datasets, batch_size, limit, model_id, kwargs={}):
         perplexity = load("perplexity", module_type="metric")
-        dataset, *_ = load_perplexity_dataset(dataset_name=datasets, splits="test")
+        kwargs["model_path"] = model_id
+        dataset, *_ = load_perplexity_dataset(dataset_name=datasets, **kwargs)
         predictions = []
         for i, sample in enumerate(dataset):
             if i == batch_size * limit:
                 break
-            predictions.append(sample["prompt"] + sample["canonical_solution"])
+            predictions.append(sample)
         return perplexity.compute(
             predictions=predictions, add_start_token=False, model_id=model_id
         )
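For context, the ground-truth side of these tests can be reproduced standalone with the Hugging Face evaluate metric, following _get_ground_truth above. The model id below is a placeholder for the value supplied by the model_id fixture, which is not shown in this diff.

from itertools import islice

from evaluate import load
from deepsparse.evaluation.integrations.perplexity import load_perplexity_dataset

model_id = "<hf-model-id>"  # placeholder; the test supplies this via a fixture

# load_perplexity_dataset returns (dataset, accumulate); only the dataset is needed here
dataset, _ = load_perplexity_dataset(dataset_name="openai_humaneval", splits="test")

# take limit * batch_size samples, as the loop in _get_ground_truth does
predictions = list(islice(dataset, 2))

perplexity = load("perplexity", module_type="metric")
result_gt = perplexity.compute(
    predictions=predictions, add_start_token=False, model_id=model_id
)
print(result_gt["perplexities"])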
