
Commit eefe3a9
KEHANGt5 authored and copybara committed

Introduces seqio.CollectingMetric class

PiperOrigin-RevId: 466117278
1 parent 421f9c3
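
At a glance: ShardedSquad is a streaming seqio.metrics.Metric that computes partial SQuAD scores per shard and combines them with merge(), while PassthroughSquad is a seqio.metrics.CollectingMetric that stores raw model outputs and defers all scoring to actual_compute(). A minimal sketch of the sharded flow, mirroring test_batch_update below (the per-shard batches inputs_a, outputs_a, etc. are hypothetical placeholders; features maps "targets" to a seqio.Feature whose vocabulary can decode token ids):

    from t5.evaluation import metrics

    # Each shard builds a partial metric from its own slice of the eval set.
    shard_a = metrics.ShardedSquad.from_model_output(inputs_a, outputs_a, features)
    shard_b = metrics.ShardedSquad.from_model_output(inputs_b, outputs_b, features)

    # merge() takes a count-weighted average, so shard sizes may differ.
    print(shard_a.merge(shard_b).compute())  # {"f1": ..., "em": ...}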

File tree

2 files changed: +321 −2 lines changed


t5/evaluation/metrics.py (+109 −1)
@@ -24,20 +24,28 @@
 import itertools
 import re
 import string
-from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

 from absl import logging
 import editdistance
+import flax
+import jax.numpy as jnp
 import numpy as np
 import sacrebleu
 import scipy.stats
+import seqio
 import sklearn.metrics
 from t5.evaluation import qa_utils
+import tensorflow.compat.v2 as tf

 from rouge_score import rouge_scorer
 from rouge_score import scoring


+ModelOutputType = seqio.metrics.ModelOutputType
+CollectingMetric = seqio.metrics.CollectingMetric
+
+
 def bleu(targets, predictions, tokenizer="intl"):
   """Computes BLEU score.
@@ -643,3 +651,103 @@ def edit_distance(targets, predictions, lower=True):
           "mean_edit": np.mean(edit_distances),
           "median_edit": np.median(edit_distances),
           "sum_edit": sum(edit_distances)}
+
+
+@flax.struct.dataclass
+class ShardedSquad(seqio.metrics.Metric):
+  """Implements SQuAD metrics, maximizing over answers per question."""
+
+  f1: float = 0.0
+  em: float = 0.0
+  count: int = 0
+  model_output_type: ModelOutputType = ModelOutputType.PREDICTION
+
+  @classmethod
+  def empty(cls) -> "ShardedSquad":
+    return cls(f1=0.0, em=0.0, count=0)
+
+  @classmethod
+  def from_model_output(
+      cls,
+      inputs: Sequence[Mapping[str, Any]],
+      model_output: np.ndarray,
+      features: Mapping[str, seqio.Feature],
+      target_field_name: str = "targets",
+      mask: Optional[np.ndarray] = None,
+      indices_2d: Optional[np.ndarray] = None) -> "ShardedSquad":
+
+    del indices_2d
+    if mask is None:
+      mask = jnp.ones((len(inputs),))
+
+    # Postprocesses the targets here.
+    postprocessed_targets = [[
+        tf.compat.as_text(answers) for answers in example["answers"]
+    ] for example, included in zip(inputs, mask) if included]
+
+    # Decodes the predictions here.
+    vocab = features[target_field_name].vocabulary
+    predictions = [
+        vocab.decode(tokens)
+        for tokens, included in zip(model_output, mask)
+        if included
+    ]
+
+    squad_result = squad(targets=postprocessed_targets, predictions=predictions)
+    return cls(f1=squad_result["f1"], em=squad_result["em"], count=mask.sum())
+
+  def merge(self, other: "ShardedSquad") -> "ShardedSquad":
+    """Returns a `ShardedSquad` that is the accumulation of `self` and `other`.
+
+    Args:
+      other: A `ShardedSquad` whose intermediate values should be accumulated
+        onto the values of `self`. Note that in a distributed setting, `other`
+        will typically be the output of a `jax.lax` parallel operator and thus
+        have a dimension added to the dataclass returned by
+        `.from_model_output()`.
+
+    Returns:
+      A new `ShardedSquad` that accumulates the value from both `self` and
+      `other`.
+    """
+    count = self.count + other.count
+    f1 = (self.f1 * self.count + other.f1 * other.count) / count
+    em = (self.em * self.count + other.em * other.count) / count
+
+    return type(self)(f1=f1, em=em, count=count)
+
+  def compute(self):
+    return {"f1": self.f1, "em": self.em}
+
+
+@flax.struct.dataclass
+class PassthroughSquad(CollectingMetric):
+  """Implements SQuAD metrics, maximizing over answers per question."""
+
+  model_output_type: ModelOutputType = ModelOutputType.PREDICTION
+
+  def actual_compute(self, task_dataset_as_numpy, task_output_features,
+                     target_field_name: str = "targets"):
+    # Postprocesses the targets here.
+    postprocessed_targets = [[
+        tf.compat.as_text(answers) for answers in example["answers"]
+    ] for example in task_dataset_as_numpy]
+
+    # We process the model outputs by the steps below.
+    # Step 1: removes padded examples using the mask.
+    indices_2d = self.values["indices_2d"][self.values["mask"] == 1]
+    model_output = self.values["model_output"][self.values["mask"] == 1]
+    assert len(postprocessed_targets) == len(indices_2d)
+
+    # Step 2: sorts the model outputs by 2d-indices, namely (shard_id,
+    # index_within_shard), to align with the targets.
+    permutation = np.lexsort((indices_2d[:, 1], indices_2d[:, 0]))
+    model_output = [
+        model_output[permutation[i]] for i in range(len(permutation))
+    ]
+
+    # Decodes the predictions here.
+    target_vocab = task_output_features[target_field_name].vocabulary
+    predictions = [
+        target_vocab.decode(tokens) for tokens in model_output
+    ]
+
+    return squad(postprocessed_targets, predictions), None
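
The subtle step in PassthroughSquad.actual_compute is the re-alignment sort: np.lexsort takes its primary key last, so passing (indices_2d[:, 1], indices_2d[:, 0]) orders model outputs by shard_id first and index_within_shard second, restoring the original dataset order before decoding. A toy check with made-up indices:

    import numpy as np

    # Rows are (shard_id, index_within_shard), deliberately shuffled.
    indices_2d = np.array([[1, 0], [0, 1], [0, 0], [1, 1]])
    permutation = np.lexsort((indices_2d[:, 1], indices_2d[:, 0]))
    print(indices_2d[permutation])  # [[0 0] [0 1] [1 0] [1 1]]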

t5/evaluation/metrics_test.py (+212 −1)
@@ -14,9 +14,12 @@

 """Tests for t5.evaluation.metrics."""

+from unittest import mock
+
 from absl.testing import absltest
+import numpy as np
+import seqio
 import sklearn.metrics
-
 from t5.evaluation import metrics
 from t5.evaluation import test_utils
@@ -706,5 +709,213 @@ def test_edit_distance(self):
         })


+def mock_decode(self, ids):
+  decode_dict = {v: k for k, v in self._encode_dict.items()}
+  words = [decode_dict[token] for token in ids if token != 0]
+  return " ".join(words)
+
+
+class PassthroughSquadTest(test_utils.BaseMetricsTest):
+
+  def test_same(self):
+    ref = "this is a string"
+    inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "this": 2,
+              "is": 3,
+              "a": 4,
+              "string": 5
+          }, vocab_size=10)
+
+      model_output = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.PassthroughSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.actual_compute(inputs, features)[0],
+                           {"em": 100, "f1": 100})
+
+  def test_different(self):
+    ref = "this is a string"
+    inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "this": 2,
+              "is": 3,
+              "a": 4,
+              "string": 5,
+              "": 6
+          }, vocab_size=10)
+
+      model_output = np.array([[6], [6]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.PassthroughSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.actual_compute(inputs, features)[0],
+                           {"em": 0, "f1": 0})
+
+  def test_big(self):
+    inputs = [
+        {"answers": ["big moose", "hippo"]},
+        {"answers": ["correct1"]},
+        {"answers": ["correct2.1", "correct2.2"]},
+        {"answers": ["a", "b"]},
+    ]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "‘a": 2,
+              "big": 3,
+              "Moose!‘": 4,
+              "wrong": 5,
+              "correct2.2": 6,
+              "c": 7
+          }, vocab_size=10)
+
+      model_output = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.PassthroughSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.actual_compute(inputs, features)[0],
+                           {"em": 25., "f1": 35.}, places=2)
+
+  def test_small(self):
+    inputs = [{"answers": ["abc abd", "$$$$"]}]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary({"abd": 2}, vocab_size=10)
+
+      model_output = np.array([[2]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.PassthroughSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.actual_compute(inputs, features)[0],
+                           {"f1": 100 * 2.0 / 3.0, "em": 0.})
+
+
+class ShardedSquadTest(test_utils.BaseMetricsTest):
+
+  def test_same(self):
+    ref = "this is a string"
+    inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "this": 2,
+              "is": 3,
+              "a": 4,
+              "string": 5
+          }, vocab_size=10)
+
+      model_output = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.ShardedSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.compute(), {"em": 100, "f1": 100})
+
+  def test_different(self):
+    ref = "this is a string"
+    inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "this": 2,
+              "is": 3,
+              "a": 4,
+              "string": 5,
+              "": 6
+          }, vocab_size=10)
+
+      model_output = np.array([[6], [6]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.ShardedSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.compute(), {"em": 0, "f1": 0})
+
+  def test_big(self):
+    inputs = [
+        {"answers": ["big moose", "hippo"]},
+        {"answers": ["correct1"]},
+        {"answers": ["correct2.1", "correct2.2"]},
+        {"answers": ["a", "b"]},
+    ]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "‘a": 2,
+              "big": 3,
+              "Moose!‘": 4,
+              "wrong": 5,
+              "correct2.2": 6,
+              "c": 7
+          }, vocab_size=10)
+
+      model_output = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.ShardedSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.compute(), {"em": 25., "f1": 35.}, places=2)
+
+  def test_small(self):
+    inputs = [{"answers": ["abc abd", "$$$$"]}]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary({"abd": 2}, vocab_size=10)
+
+      model_output = np.array([[2]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric = metrics.ShardedSquad.from_model_output(
+          inputs, model_output, features)
+      self.assertDictClose(metric.compute(), {"f1": 100 * 2.0 / 3.0, "em": 0.})
+
+  def test_batch_update(self):
+    inputs1 = [
+        {"answers": ["big moose", "hippo"]},
+        {"answers": ["correct1"]}
+    ]
+    inputs2 = [
+        {"answers": ["correct2.1", "correct2.2"]},
+        {"answers": ["a", "b"]},
+    ]
+
+    with mock.patch.object(
+        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
+      vocabulary = seqio.test_utils.MockVocabulary(
+          {
+              "‘a": 2,
+              "big": 3,
+              "Moose!‘": 4,
+              "wrong": 5,
+              "correct2.2": 6,
+              "c": 7
+          }, vocab_size=10)
+
+      model_output1 = np.array([[2, 3, 4], [5, 0, 0]])
+      model_output2 = np.array([[6], [7]])
+      features = {"targets": seqio.Feature(vocabulary)}
+      metric1 = metrics.ShardedSquad.from_model_output(
+          inputs1, model_output1, features)
+      metric2 = metrics.ShardedSquad.from_model_output(
+          inputs2, model_output2, features)
+      metric = metric1.merge(metric2)
+      self.assertDictClose(metric.compute(), {"em": 25., "f1": 35.}, places=2)
+
+
 if __name__ == "__main__":
   absltest.main()
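
As a sanity check on ShardedSquad.merge, the expected values in test_batch_update can be reproduced by hand. Working the fixtures through the squad metric gives roughly f1=20, em=0 on the first shard's two examples and f1=50, em=50 on the second shard's two (these per-shard values are derived by hand, not taken from the tests), and merge() averages them weighted by example count:

    # Count-weighted averaging performed by ShardedSquad.merge, with per-shard
    # values worked out by hand from the test_batch_update fixtures above.
    n1, f1_1, em_1 = 2, 20.0, 0.0
    n2, f1_2, em_2 = 2, 50.0, 50.0
    print((f1_1 * n1 + f1_2 * n2) / (n1 + n2))  # 35.0, matching the test
    print((em_1 * n1 + em_2 * n2) / (n1 + n2))  # 25.0, matching the test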
