
Commit b82b49b

dbogunowicz and dsikka authored
[DeepSparse Evaluation API] Perplexity (#1555)
* initial commit
* Update src/deepsparse/evaluation/integrations/__init__.py
* design ready, time to define additional features
* split prep_for_generation operator
* fix logits
* update non-kv cache pipeline and tests
* add tests to address edge cases
* add condition to check of kv_cache full during prompt inference, add test to cover this case, revert debugging changes
* fix typing
* remove commented code
* remove irrelevant condition
* perplexity for non-kv cache pipelines works!
* logic is working
* ready for review
* [DeepSparse Evaluation API] Perplexity eval support for `openai_humaneval`, `c4`, `wikitext2` (#1586)
* fix tests 2
* initial commit
* add return to a function
* make script more robust

---------

Co-authored-by: Dipika Sikka <[email protected]>
1 parent e0b4f36 commit b82b49b

File tree: 9 files changed (+448, -9 lines)

setup.py (+1)

@@ -149,6 +149,7 @@ def _parse_requirements_file(file_path):
     "datasets<2.16",
     "accelerate<0.26",
     "seqeval",
+    "evaluate",
 ]
 _sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps

src/deepsparse/evaluation/evaluator.py (+3)

@@ -16,6 +16,9 @@
 from typing import List, Optional, Union
 
 from deepsparse import Pipeline
+from deepsparse.evaluation.integrations.perplexity import (  # noqa
+    integration_eval as integration_eval_perplexity,
+)
 from deepsparse.evaluation.registry import EvaluationRegistry
 from deepsparse.evaluation.results import Result
 from deepsparse.evaluation.utils import create_pipeline
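The added import is there for its side effect: importing the perplexity module executes its @EvaluationRegistry.register(name=PERPLEXITY) decorator (introduced later in this commit), which is what makes the integration resolvable by name. A stripped-down sketch of that registration pattern, with a placeholder body, is:

from deepsparse.evaluation.registry import EvaluationRegistry
from deepsparse.evaluation.results import Result
from deepsparse.evaluation.utils import PERPLEXITY


@EvaluationRegistry.register(name=PERPLEXITY)
def integration_eval(pipeline, datasets="openai_humaneval", **kwargs) -> Result:
    # placeholder body; the real implementation lives in
    # src/deepsparse/evaluation/integrations/perplexity.py
    ...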

src/deepsparse/evaluation/integrations/__init__.py (+1)

@@ -31,3 +31,4 @@ def try_import_lm_evaluation_harness(raise_error=True):
 
 if try_import_lm_evaluation_harness(raise_error=False):
     from .lm_evaluation_harness import *
+from .perplexity import *
src/deepsparse/evaluation/integrations/perplexity.py (new file, +278)

@@ -0,0 +1,278 @@

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import numpy
from tqdm import tqdm

from datasets import load_dataset
from deepsparse import Pipeline
from deepsparse.evaluation.registry import EvaluationRegistry
from deepsparse.evaluation.results import Dataset, Evaluation, Metric, Result
from deepsparse.evaluation.utils import PERPLEXITY
from deepsparse.transformers.metrics import Perplexity
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import (
    TextGenerationPipelineNoCache,
)
from deepsparse.transformers.utils.eval_helpers import (
    HumanEvalIteratorWrapper,
    process_concatenated_datasets,
)


"""
Integration for the evaluation module
that computes the perplexity of a model on a dataset
"""
_LOGGER = logging.getLogger(__name__)


@EvaluationRegistry.register(name=PERPLEXITY)
def integration_eval(
    pipeline: Pipeline,
    datasets: Union[List[str], str] = "openai_humaneval",
    batch_size: int = 1,
    limit: Optional[int] = None,
    accumulate: Optional[bool] = None,
    splits: Union[List[str], str, None] = "test",
    metrics: Union[List[str], str, None] = None,
    **kwargs,
) -> Result:
    """
    A function that computes the perplexity of a pipeline given a set
    of dataset names.

    :param pipeline: the pipeline to evaluate. The assumed pipeline
        is a TextGenerationPipeline, either with or without the KV
        cache support
    :param datasets: the names of dataset(s) to evaluate on
    :param batch_size: the batch size to use for evaluation
    :param splits: the split of the dataset to evaluate on. Default is "test"
    :param metrics: the metrics to compute. Default is None
    :param limit: the number of batches to evaluate on. Default is None
        (evaluates on entire dataset)
    :param accumulate: whether the perplexity computation should
        accumulate negative log-likelihood over samples. Defaults to
        the default accumulate variable inferred from the dataset in
        `datasets`. If not None, it will override the inferred accumulate
        variable.
    :return: a Result object containing the raw and formatted results
    """
    metrics = metrics or PERPLEXITY
    if metrics != PERPLEXITY:
        raise ValueError(f"Invalid metric {metrics} for perplexity evaluation")
    if splits is None:
        splits = "test"
        _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.")
    datasets = datasets if isinstance(datasets, list) else [datasets]
    results_raw = defaultdict(str)
    for dataset_name in datasets:
        results_raw[dataset_name] = defaultdict()
        dataset, _accumulate = load_perplexity_dataset(
            dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs
        )
        if accumulate is None:
            accumulate = _accumulate
        else:
            _LOGGER.info(
                f"Argument `accumulate` set to {accumulate}. "
                "Overriding the inferred accumulate variable from the dataset."
            )

        perplexity = run_perplexity(
            pipeline=pipeline,
            dataset=dataset,
            batch_size=batch_size,
            accumulate=accumulate,
            limit=limit,
        )

        results_raw[dataset_name] = defaultdict()
        results_raw[dataset_name]["results"] = perplexity
        results_raw[dataset_name]["split"] = splits

    results = Result(
        # omit storing raw results. they can potentially
        # contain numpy arrays that are not serializable.
        # all the information is stored in the formatted results
        raw=None,
        formatted=format_raw_results(results_raw),
    )

    return results


def run_perplexity(
    pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline],
    dataset: "Dataset",
    batch_size: int,
    accumulate: bool,
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Compute the perplexity of a pipeline given a dataset.

    :param pipeline: the pipeline to evaluate. The assumed pipeline
        is a TextGenerationPipeline, either with or without the KV
        cache support
    :param dataset: the dataset to evaluate on
    :param batch_size: the batch size to use for evaluation
    :param accumulate: whether the perplexity computation should
        accumulate negative log-likelihood over samples
    :param limit: the number of batches to evaluate on. Default is None
        (evaluates on entire dataset)

    :return: a dictionary containing the perplexity results
    """

    perplexity = Perplexity(accumulate=accumulate)

    batch = []
    for idx, sample in _enumerate_progress(
        dataset, max_steps=None if limit is None else limit * batch_size
    ):

        if limit is not None:
            # stop if we have reached the #limit
            # number of batches to be processed
            if idx >= limit * batch_size:
                break

        batch.append(sample)

        if len(batch) == batch_size:
            if isinstance(pipeline, TextGenerationPipelineNoCache):
                out = pipeline(
                    prompt=batch,
                    output_scores=True,
                    include_prompt_logits=True,
                    return_input_tokens=True,
                )
            else:
                out = pipeline(
                    prompt=batch,
                    output_scores=True,
                    max_new_tokens=0,
                    include_prompt_logits=True,
                    return_input_tokens=True,
                )

            for s in range(batch_size):
                # Need to remove tokens that were masked
                input_ids = out.input_tokens["input_ids"][s].flatten()
                attention_mask = out.input_tokens["attention_mask"][s].flatten()
                logits = out.generations[s].score
                if batch_size > 1 and isinstance(
                    pipeline, TextGenerationPipelineNoCache
                ):
                    logits = logits[-attention_mask.sum() :, :]

                logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
                input_ids = numpy.compress(attention_mask, input_ids)[1:]

                # Add predictions (logits) and targets (input_ids) to metric
                perplexity.add_batch(logits, input_ids)

            batch.clear()

    return perplexity.compute()


def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
    """
    Format the raw perplexity results into a list of
    Evaluation objects.

    :param results: the raw results from perplexity computation
    :return: the formatted results as a list of Evaluation objects
    """
    formatted_results = []
    for dataset_name, dataset_result in results.items():
        metrics = []
        for metric_name, metric_value in dataset_result["results"].items():
            if isinstance(metric_value, numpy.ndarray):
                metric_value = metric_value.tolist()
            metric = Metric(name=metric_name, value=metric_value)
            metrics.append(metric)
        dataset = Dataset(type=None, name=dataset_name, split=dataset_result["split"])
        evaluation = Evaluation(
            task="perplexity",
            dataset=dataset,
            metrics=metrics,
            samples=None,
        )
        formatted_results.append(evaluation)
    return formatted_results


def load_perplexity_dataset(
    dataset_name: str,
    splits: Union[List[str], str] = "test",
    pipeline: Optional[Pipeline] = None,
    **kwargs,
):
    """
    Function to load the dataset for perplexity computation.
    Eventually we want to load the dataset from the nm_utils

    :param dataset_name: the name of the dataset to load
    :param splits: the splits to load from the dataset. Default is "test"
    :param pipeline: the pipeline to use for loading the dataset. The pipeline
        is used to infer the model path and sequence length to use for loading
        the dataset. This argument can be omitted if the appropriate kwargs
        are provided, or if the dataset does not require a
        process_concatenated_datasets function to load the dataset.
    :param kwargs: additional keyword arguments to pass to the dataset loading function
    :return: the dataset and whether to accumulate perplexity over samples
    """
    if isinstance(splits, list):
        raise NotImplementedError("Evaluation on multiple splits not implemented")

    if dataset_name == "openai_humaneval":
        dataset = load_dataset(dataset_name, split=splits)
        dataset = HumanEvalIteratorWrapper(dataset)
        accumulate = False
    elif dataset_name in {"wikitext2", "c4"}:
        # fetch max_sequence_length from pipeline if not provided
        max_sequence_length = kwargs.pop("max_sequence_length", None)
        if max_sequence_length is None and pipeline is not None:
            max_sequence_length = pipeline.sequence_length

        # fetch model_path from pipeline if not provided
        model_path = kwargs.pop("model_path", None)
        if model_path is None and pipeline is not None:
            model_path = os.path.dirname(pipeline.model_path)

        dataset = process_concatenated_datasets(
            dataset_name,
            model_path=model_path,
            max_sequence_length=max_sequence_length,
            split=splits,
            **kwargs,
        )
        accumulate = True
    else:
        raise NotImplementedError(f"Dataset {dataset_name} not implemented")

    return dataset, accumulate


def _enumerate_progress(dataset, max_steps):
    progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
    return enumerate(progress_bar)
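
As a quick orientation, the sketch below shows how integration_eval might be invoked directly on a text-generation pipeline. The deployment path is a placeholder and the Pipeline.create call is an assumption for illustration; only the integration_eval arguments come from the signature above. Note that limit counts batches, not samples, and that for "wikitext2"/"c4" the loader will read sequence_length and model_path off the pipeline unless they are passed as kwargs.

from deepsparse import Pipeline
from deepsparse.evaluation.integrations.perplexity import integration_eval

# placeholder path to a local text-generation deployment (model.onnx + tokenizer)
pipeline = Pipeline.create(task="text_generation", model_path="path/to/deployment")

result = integration_eval(
    pipeline=pipeline,
    datasets="openai_humaneval",  # or "wikitext2" / "c4"
    batch_size=1,
    limit=2,  # evaluate only the first two batches, handy as a smoke test
    splits="test",
)
print(result.formatted[0].metrics)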

src/deepsparse/evaluation/results.py (+2, -2)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 import yaml
 from pydantic import BaseModel, Field
@@ -32,7 +32,7 @@
 
 class Metric(BaseModel):
     name: str = Field(description="Name of the metric")
-    value: float = Field(description="Value of the metric")
+    value: Union[float, List[float]] = Field(description="Value of the metric")
 
 
 class Dataset(BaseModel):
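
The Metric schema is relaxed so that value may be a single float or a list of floats; the perplexity integration relies on this when format_raw_results converts numpy arrays to plain lists. A small illustration (the metric names and numbers here are made up for the example):

from deepsparse.evaluation.results import Metric

# scalar-valued metric, as before
mean_ppl = Metric(name="mean_perplexity", value=12.7)

# list-valued metric, now permitted by Union[float, List[float]]
per_sample = Metric(name="perplexities", value=[13.1, 12.4, 12.9])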

src/deepsparse/evaluation/utils.py (+2)

@@ -27,7 +27,9 @@
     "resolve_integration",
 ]
 _LOGGER = logging.getLogger(__name__)
+
 LM_EVALUATION_HARNESS = "lm-evaluation-harness"
+PERPLEXITY = "perplexity"
 
 
 def potentially_check_dependency_import(integration_name: str) -> bool:

src/deepsparse/transformers/metrics.py (+1, -1)

@@ -20,7 +20,7 @@
 
 import numpy
 
-from deepsparse.utils import numpy_log_softmax
+from deepsparse.utils.data import numpy_log_softmax
 
 
 __all__ = [
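
Only the import path changes here: numpy_log_softmax now comes from the concrete deepsparse.utils.data module rather than the deepsparse.utils package namespace. For readers unfamiliar with the helper, a plain-numpy equivalent of a numerically stable log-softmax looks roughly like the sketch below (an illustration of the math, not the library's implementation):

import numpy


def log_softmax(x: numpy.ndarray, axis: int = -1) -> numpy.ndarray:
    # subtract the per-row max for numerical stability,
    # then normalize in log space
    shifted = x - numpy.max(x, axis=axis, keepdims=True)
    return shifted - numpy.log(numpy.sum(numpy.exp(shifted), axis=axis, keepdims=True))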
