"""


-from typing import Dict, Optional
+from typing import Any, Dict, List, Optional

import numpy
+from tqdm import tqdm

+import torch
+from deepsparse import Pipeline
+from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from sklearn.metrics import precision_recall_fscore_support


__all__ = [
    "PrecisionRecallF1",
+    "Perplexity",
]


+class Perplexity:
+    def __init__(self, pipeline: Pipeline, batch_size: int = 16):
+        """
+        Given the pipeline, compute the perplexity of the model
+        on the given text input.
+
+        Code adapted from:
+        https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501
+
+        :param pipeline: The pipeline to use for text generation
+        :param batch_size: The batch size to split the input text into
+            non-overlapping batches
+        """
+        if not isinstance(pipeline, TextGenerationPipeline):
+            raise ValueError(
+                "Perplexity can only be computed for text generation pipelines"
+            )
+        self._pipeline = pipeline
+        self._batch_size = batch_size
+        self._sequence_length = pipeline.sequence_length
+        self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+        self.perplexities = []
+
+    def add_batch(self, predictions: List[str]):
+        """
+        Run the model on the given input sequences and compute the perplexity.
+        The resulting perplexities are appended to the list of perplexities.
+
+        :param predictions: The predictions to compute perplexity on
+        """
+        # tokenize the input text
+        encodings = self._pipeline.tokenizer(
+            predictions,
+            return_attention_mask=True,
+            max_length=self._sequence_length,
+            truncation=True,
+            padding="max_length",
+        )
+
+        encoded_texts = encodings["input_ids"]
+        attention_masks = encodings["attention_mask"]
+
+        for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)):
+            end_index = min(start_index + self._batch_size, len(encoded_texts))
+            encoded_batch = encoded_texts[start_index:end_index]
+            attention_mask = attention_masks[start_index:end_index]
+
+            # run the pipeline on the current batch of input sequences only
+            out = self._pipeline(
+                sequences=predictions[start_index:end_index],
+                return_logits=True,
+                truncate=True,
+            )
+            logits = out.logits
+
+            labels = numpy.stack(encoded_batch)
+            attention_mask = numpy.stack(attention_mask)
+
+            # the tokenizer pads on the left, so each row's meaningful tokens sit
+            # at the end; attention_mask.sum counts the non-padded tokens per row
+            num_non_padded_entries = attention_mask.sum(axis=1)
+
+            # roll each row forward by its number of non-padded tokens so the
+            # meaningful tokens wrap around to the front of the array
+            for i, num_non_padded in enumerate(num_non_padded_entries):
+                logits[i] = numpy.roll(logits[i], num_non_padded, axis=0)
+                labels[i] = numpy.roll(labels[i], num_non_padded, axis=0)
+                attention_mask[i] = numpy.roll(
+                    attention_mask[i], num_non_padded, axis=0
+                )
+
+            # shift the logits and labels to create the input and target
+            # for the loss function
+            shift_logits = logits[:, :-1, :]
+            shift_labels = labels[:, 1:]
+            shift_attention_mask_batch = attention_mask[:, 1:]
+
+            # compute perplexity for this batch
+            perplexity_batch = torch.exp(
+                (
+                    self._loss_fct(
+                        torch.tensor(shift_logits.transpose(0, 2, 1)),
+                        torch.tensor(shift_labels),
+                    )
+                    * torch.tensor(shift_attention_mask_batch)
+                ).sum(1)
+                / torch.tensor(shift_attention_mask_batch).sum(1)
+            )
+            self.perplexities.extend(perplexity_batch.numpy().tolist())
+
+    def compute(self) -> Dict[str, Any]:
+        """
+        :return: A dictionary containing the mean perplexity
+            and the list of perplexities
+        """
+        return {
+            "mean_perplexity": numpy.mean(self.perplexities),
+            "perplexities": self.perplexities,
+        }
+
+
class PrecisionRecallF1:
    def __init__(self, id_to_label: Optional[Dict[int, str]] = None):
        self._id_to_label = id_to_label
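
Usage note: a minimal sketch of how the new Perplexity class might be driven, not part of the diff. The model stub, the Pipeline.create keyword arguments, and the import path for Perplexity are assumptions made for illustration only.

from deepsparse import Pipeline

# assumed import path for the class added in this diff
from deepsparse.transformers.metrics import Perplexity

# hypothetical model stub and pipeline arguments
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:example/text-generation-model",
    sequence_length=512,
)

metric = Perplexity(pipeline=pipeline, batch_size=16)
metric.add_batch(["The quick brown fox jumps over the lazy dog."])
results = metric.compute()
print(results["mean_perplexity"], results["perplexities"])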
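
The roll-based re-padding in add_batch can be seen in isolation with the toy snippet below (values invented for illustration): rolling a left-padded row forward by its number of non-padded tokens moves the real tokens to the front of the array.

import numpy

# toy left-padded row: two pad tokens (0) followed by three real tokens
row = numpy.array([0, 0, 7, 8, 9])
mask = numpy.array([0, 0, 1, 1, 1])

num_non_padded = mask.sum()               # 3 real tokens
rolled = numpy.roll(row, num_non_padded)
print(rolled)                             # [7 8 9 0 0] -> real tokens now at the front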
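
The per-sequence perplexity computed in add_batch is the exponential of the token cross-entropy averaged over non-padded positions only. A minimal toy reproduction of that masked average (shapes and values invented for illustration):

import torch

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

# one sequence, four shifted positions, vocabulary of three tokens
shift_logits = torch.randn(1, 3, 4)            # (batch, vocab, positions)
shift_labels = torch.randint(0, 3, (1, 4))     # (batch, positions)
shift_mask = torch.tensor([[1, 1, 1, 0]])      # last position is padding

token_loss = loss_fct(shift_logits, shift_labels)                 # (batch, positions)
mean_loss = (token_loss * shift_mask).sum(1) / shift_mask.sum(1)  # masked average
perplexity = torch.exp(mean_loss)                                 # one value per sequence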