diff --git a/src/deepsparse/evaluation/integrations/perplexity.py b/src/deepsparse/evaluation/integrations/perplexity.py
index 8a759261ac..a9a3f3d8a3 100644
--- a/src/deepsparse/evaluation/integrations/perplexity.py
+++ b/src/deepsparse/evaluation/integrations/perplexity.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+import os
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union

@@ -29,6 +30,10 @@
 from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import (
     TextGenerationPipelineNoCache,
 )
+from deepsparse.transformers.utils.eval_helpers import (
+    HumanEvalIteratorWrapper,
+    process_concatenated_datasets,
+)


 """
@@ -41,9 +46,10 @@
 @EvaluationRegistry.register(name=PERPLEXITY)
 def integration_eval(
     pipeline: Pipeline,
-    datasets: Union[List[str], str],
+    datasets: Union[List[str], str] = "openai_humaneval",
     batch_size: int = 1,
     limit: Optional[int] = None,
+    accumulate: Optional[bool] = None,
     splits: Union[List[str], str, None] = "test",
     metrics: Union[List[str], str, None] = None,
     **kwargs,
@@ -61,6 +67,11 @@ def integration_eval(
     :param metrics: the metrics to compute. Default is None
     :param limit: the number of batches to evaluate on. Default is None
         (evaluates on entire dataset)
+    :param accumulate: whether the perplexity computation should
+        accumulate negative log-likelihood over samples. Defaults to
+        the accumulate setting inferred from the dataset in
+        `datasets`. If not None, this argument overrides the
+        inferred setting.
     :return: a Result object containing the raw and formatted results
     """
     metrics = metrics or PERPLEXITY
@@ -69,12 +80,21 @@ def integration_eval(
     if splits is None:
         splits = "test"
         _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.")
-
     datasets = datasets if isinstance(datasets, list) else [datasets]
     results_raw = defaultdict(str)
     for dataset_name in datasets:
         results_raw[dataset_name] = defaultdict()
-        dataset, accumulate = load_perplexity_dataset(dataset_name, splits)
+        dataset, _accumulate = load_perplexity_dataset(
+            dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs
+        )
+        if accumulate is None:
+            accumulate = _accumulate
+        else:
+            _LOGGER.info(
+                f"Argument `accumulate` set to {accumulate}. "
+                "Overriding the inferred accumulate variable from the dataset."
+            )
+
         perplexity = run_perplexity(
             pipeline=pipeline,
             dataset=dataset,
@@ -100,7 +120,7 @@ def integration_eval(

 def run_perplexity(
     pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline],
-    dataset: Any,  # TODO: Edit, once we agree on the dataset registry
+    dataset: "Dataset",
     batch_size: int,
     accumulate: bool,
     limit: Optional[int] = None,
@@ -127,8 +147,6 @@ def run_perplexity(
     for idx, sample in _enumerate_progress(
         dataset, max_steps=None if limit is None else limit * batch_size
     ):
-        # TODO: To remove when we have support for more datasets
-        sample = sample["prompt"] + sample["canonical_solution"]

         if limit is not None:
             # stop if we have reached the #limit
@@ -203,21 +221,57 @@ def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
     return formatted_results


-def load_perplexity_dataset(dataset_name: str, splits: Union[List[str], str] = "test"):
+def load_perplexity_dataset(
+    dataset_name: str,
+    splits: Union[List[str], str] = "test",
+    pipeline: Optional[Pipeline] = None,
+    **kwargs,
+):
     """
-    Dummy function to load the dataset for perplexity computation.
+    Function to load the dataset for perplexity computation.
     Eventually we want to load the dataset from the nm_utils
+
+    :param dataset_name: the name of the dataset to load
+    :param splits: the splits to load from the dataset. Default is "test"
+    :param pipeline: the pipeline to use for loading the dataset. The pipeline
+        is used to infer the model path and sequence length to use for loading
+        the dataset. This argument can be omitted if the appropriate kwargs
+        are provided, or if the dataset does not require
+        `process_concatenated_datasets` to load.
+    :param kwargs: additional keyword arguments to pass to the dataset loading function
+    :return: the dataset and whether to accumulate perplexity over samples
     """
     if isinstance(splits, list):
         raise NotImplementedError("Evaluation on multiple splits not implemented")

     if dataset_name == "openai_humaneval":
         dataset = load_dataset(dataset_name, split=splits)
+        dataset = HumanEvalIteratorWrapper(dataset)
         accumulate = False
-        return dataset, accumulate
+    elif dataset_name in {"wikitext2", "c4"}:
+        # fetch max_sequence_length from pipeline if not provided
+        max_sequence_length = kwargs.pop("max_sequence_length", None)
+        if max_sequence_length is None and pipeline is not None:
+            max_sequence_length = pipeline.sequence_length
+
+        # fetch model_path from pipeline if not provided
+        model_path = kwargs.pop("model_path", None)
+        if model_path is None and pipeline is not None:
+            model_path = os.path.dirname(pipeline.model_path)
+
+        dataset = process_concatenated_datasets(
+            dataset_name,
+            model_path=model_path,
+            max_sequence_length=max_sequence_length,
+            split=splits,
+            **kwargs,
+        )
+        accumulate = True
     else:
         raise NotImplementedError(f"Dataset {dataset_name} not implemented")

+    return dataset, accumulate
+

 def _enumerate_progress(dataset, max_steps):
     progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py
index 8e23af2fcf..0e7c24c8b6 100644
--- a/src/deepsparse/transformers/metrics.py
+++ b/src/deepsparse/transformers/metrics.py
@@ -275,6 +275,5 @@ def _cross_entropy(
         neg_log_likelihoods = neg_log_likelihoods.mean(axis=-1)
     elif reduction == "sum":
         neg_log_likelihoods = neg_log_likelihoods.sum(axis=-1)
-    print(neg_log_likelihoods)

     return neg_log_likelihoods
diff --git a/src/deepsparse/transformers/utils/eval_helpers.py b/src/deepsparse/transformers/utils/eval_helpers.py
index 4c0e68b9de..012520b9b5 100644
--- a/src/deepsparse/transformers/utils/eval_helpers.py
+++ b/src/deepsparse/transformers/utils/eval_helpers.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Mapping, Union
+from typing import List, Union

 import numpy
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
@@ -27,7 +27,8 @@ def process_concatenated_datasets(
     dataset_name: str,
     model_path: str,
     max_sequence_length: int,
-    kwargs: Mapping,
+    split: str = "test",
+    **kwargs,
 ) -> list:
     """
     Concatenate text datasets and split them into chunks text that, after
@@ -38,6 +39,8 @@
             Options: "wikitext2" or "c4".
         model_path (str): The path to a pretrained transformer model for tokenization.
         max_sequence_length (int): The maximum number of tokens in each sequence.
+        split (str, optional): The split of the dataset to use.
+            Default is "test".
         kwargs (mapping): Additional keyword arguments.
             - eos (str, optional): The end-of-sentence token.
                 Default is "\n\n" for wikitext2 and "" for c4.
@@ -65,13 +68,13 @@ def process_concatenated_datasets(
         eos = kwargs.get("eos", "\n\n")
         bos = kwargs.get("bos", "")

-        raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+        raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
         raw_text = raw_dataset["text"]
     elif dataset_name == "c4":
         eos = kwargs.get("eos", "<|endoftext|>")
         bos = kwargs.get("bos", "")
         raw_samples = kwargs.get("raw_samples", None)
-        data_file = kwargs.get("data_file", 0)
+        data_file = kwargs.get("data_file", None)
         if data_file is not None:
             raw_dataset = load_dataset(
                 "allenai/c4",
                 data_files={
                     "validation": f"en/c4-validation.{data_file:05d}-of-00008.json.gz"
                 },
-                split="validation",
+                split=split,
             )
         else:
             raw_dataset = load_dataset(
                 "allenai/c4",
                 "allenai--c4",
-                split="validation",
+                split=split,
             )
         if raw_samples is not None:
             raw_dataset = raw_dataset[:raw_samples]
@@ -181,3 +184,22 @@ def _split_text_by_tokens(
     )

     return split_text
+
+
+class HumanEvalIteratorWrapper:
+    """
+    Wrapper around the `openai_humaneval` dataset
+    that joins the prompt and the canonical solution
+    into a single string during iteration.
+    """
+
+    def __init__(self, dataset):
+        self.iterator = iter(dataset)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        # Get the next sample from the original iterator
+        sample = next(self.iterator)
+        return sample["prompt"] + sample["canonical_solution"]
diff --git a/tests/deepsparse/evaluation/integrations/test_perplexity.py b/tests/deepsparse/evaluation/integrations/test_perplexity.py
index 04fec4458c..b156e5b9a4 100644
--- a/tests/deepsparse/evaluation/integrations/test_perplexity.py
+++ b/tests/deepsparse/evaluation/integrations/test_perplexity.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from copy import copy
+
 import numpy as np
 import pytest

@@ -37,26 +39,28 @@ def model_id():
     "datasets",
     [
         "openai_humaneval",
-        # TODO: add more datasets
-        # "c4",
-        # "wikitext2",
+        "wikitext2",
     ],
 )
-@pytest.mark.parametrize(
-    "batch_size",
-    [1, 3],
-)
-class TestLMEvaluationHarness:
-    limit = 4
+@pytest.mark.parametrize("batch_size", [1, 2])
+class TestPerplexity:
+    limit = 2

     def test_perplexity_ground_truth_equal_pipeline(
         self, model_path, model_id, datasets, batch_size
     ):
+        # setting max_sequence_length to 16 to speed up the test
+        kwargs_ground_truth = (
+            dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {}
+        )
+        kwargs = copy(kwargs_ground_truth)
+
         result_gt = self._get_ground_truth(
             datasets=datasets,
             batch_size=batch_size,
             limit=self.limit,
             model_id=model_id,
+            kwargs=kwargs_ground_truth,
         )

         result = integration_eval(
@@ -67,26 +71,34 @@ def test_perplexity_ground_truth_equal_pipeline(
             datasets=datasets,
             batch_size=batch_size,
             limit=self.limit,
+            # we are setting accumulate=False to compare
+            # with the torch ground truth apples to apples
+            accumulate=False,
+            **kwargs,
         )

         perplexities = result.formatted[0].metrics[0].value
         perplexities_gt = result_gt["perplexities"]
-        # TODO: This seemingly big error is due to the fact that
-        # small (1e-2) differences in neg log likelihood get
-        # amplified when computing perplexity
-        # (when applying exp function)
         assert np.allclose(perplexities, perplexities_gt, rtol=0.1)

     def test_perplexity_kv_cache_pipeline_equal_no_kv_cache_pipeline(
-        self, model_path, datasets, batch_size
+        self, model_path, model_id, datasets, batch_size
     ):
+
+        kwargs_ground_truth = (
+            dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {}
+        )
+        kwargs = copy(kwargs_ground_truth)
+
         result_kv_cache = integration_eval(
             pipeline=TextGenerationPipeline(
                 model_path="hf:mgoin/TinyStories-1M-deepsparse",
                 engine_type="onnxruntime",
             ),
             datasets=datasets,
+            model_path=model_id,
             batch_size=batch_size,
             limit=self.limit,
+            **kwargs,
         )

         result_non_kv_cache = integration_eval(
@@ -98,25 +110,23 @@ def test_perplexity_kv_cache_pipeline_equal_no_kv_cache_pipeline(
             datasets=datasets,
             batch_size=batch_size,
             limit=self.limit,
+            **kwargs,
         )

         perplexities_kv_cache = result_kv_cache.formatted[0].metrics[0].value
         perplexities_non_kv_cache = result_non_kv_cache.formatted[0].metrics[0].value
-        # TODO: This seemingly big error is due to the fact that
-        # small (1e-2) differences in neg log likelihood get
-        # amplified when computing perplexity
-        # (when applying exp function).
         np.allclose(perplexities_kv_cache, perplexities_non_kv_cache, rtol=0.1)

     @staticmethod
-    def _get_ground_truth(datasets, batch_size, limit, model_id):
+    def _get_ground_truth(datasets, batch_size, limit, model_id, kwargs={}):
         perplexity = load("perplexity", module_type="metric")
-        dataset, *_ = load_perplexity_dataset(dataset_name=datasets, splits="test")
+        kwargs["model_path"] = model_id
+        dataset, *_ = load_perplexity_dataset(dataset_name=datasets, **kwargs)
         predictions = []
         for i, sample in enumerate(dataset):
             if i == batch_size * limit:
                 break
-            predictions.append(sample["prompt"] + sample["canonical_solution"])
+            predictions.append(sample)
         return perplexity.compute(
             predictions=predictions, add_start_token=False, model_id=model_id
         )
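A minimal usage sketch of the arguments introduced by this patch, for reference only (not part of the diff). It assumes the `TextGenerationPipeline` import path and the TinyStories model stub used in the test file above; everything else mirrors the code paths added in `integration_eval` and `load_perplexity_dataset`.

# Illustrative sketch only: exercises the new `accumulate` override and the
# `max_sequence_length` kwarg that integration_eval forwards through **kwargs
# to load_perplexity_dataset / process_concatenated_datasets.
from deepsparse import TextGenerationPipeline  # import path assumed

from deepsparse.evaluation.integrations.perplexity import integration_eval

pipeline = TextGenerationPipeline(
    model_path="hf:mgoin/TinyStories-1M-deepsparse",  # model stub from the tests
    engine_type="onnxruntime",
)

result = integration_eval(
    pipeline=pipeline,
    datasets="wikitext2",      # "c4" and "openai_humaneval" are also supported
    batch_size=1,
    limit=2,                   # number of batches to evaluate
    accumulate=False,          # override the value inferred from the dataset
    max_sequence_length=16,    # forwarded via **kwargs to the dataset loader
)
print(result.formatted[0].metrics[0].value)

With `accumulate` left as None, wikitext2 and c4 default to accumulating negative log-likelihood across samples, while openai_humaneval reports per-sample perplexities.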