|
14 | 14 |
|
15 | 15 | """Tests for t5.evaluation.metrics."""
|
16 | 16 |
|
| 17 | +from unittest import mock |
| 18 | + |
17 | 19 | from absl.testing import absltest
|
| 20 | +import numpy as np |
| 21 | +import seqio |
18 | 22 | import sklearn.metrics
|
19 |
| - |
20 | 23 | from t5.evaluation import metrics
|
21 | 24 | from t5.evaluation import test_utils
|
22 | 25 |
|
@@ -706,5 +709,213 @@ def test_edit_distance(self):
|
706 | 709 | })
|
707 | 710 |
|
708 | 711 |
|
def mock_decode(self, ids):
    """Turn a sequence of token ids back into a space-separated string.

    Intended to be patched in as a vocabulary's `decode` method. The lookup
    table is built by inverting `self._encode_dict` (word -> id) on each call.

    Args:
        self: object exposing an `_encode_dict` mapping of word to id.
        ids: iterable of token ids; every id equal to 0 is skipped (the tests
            in this file use 0 to fill out shorter rows).

    Returns:
        The words for the remaining ids, joined by single spaces. An input
        with no non-zero ids yields the empty string.

    Raises:
        KeyError: if a non-zero id has no entry in `self._encode_dict`.
    """
    id_to_word = dict((idx, word) for word, idx in self._encode_dict.items())
    return " ".join(id_to_word[tok] for tok in ids if tok != 0)
| 716 | + |
| 717 | + |
class PassthroughSquadTest(test_utils.BaseMetricsTest):
    """Checks the "em"/"f1" values reported by `metrics.PassthroughSquad`."""

    def _squad_scores(self, token_to_id, inputs, model_output):
        """Score `model_output` against `inputs` with `PassthroughSquad`.

        `MockVocabulary.decode` is replaced by `mock_decode` for the whole
        run, covering vocabulary construction, metric creation and the
        compute call.

        Args:
            token_to_id: word -> id mapping handed to `MockVocabulary`.
            inputs: list of example dicts, each holding an "answers" list.
            model_output: array of predicted token ids, one row per example.

        Returns:
            The first element of the value returned by `actual_compute`.
        """
        with mock.patch.object(
            seqio.test_utils.MockVocabulary, "decode", new=mock_decode
        ):
            vocabulary = seqio.test_utils.MockVocabulary(
                token_to_id, vocab_size=10
            )
            features = {"targets": seqio.Feature(vocabulary)}
            metric = metrics.PassthroughSquad.from_model_output(
                inputs, model_output, features
            )
            return metric.actual_compute(inputs, features)[0]

    def test_same(self):
        """Predictions equal to one of the references score 100 on both."""
        ref = "this is a string"
        inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]
        token_to_id = {"this": 2, "is": 3, "a": 4, "string": 5}
        predictions = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])

        scores = self._squad_scores(token_to_id, inputs, predictions)

        self.assertDictClose(scores, {"em": 100, "f1": 100})

    def test_different(self):
        """Predictions that decode to the empty string score 0 on both."""
        ref = "this is a string"
        inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]
        token_to_id = {"this": 2, "is": 3, "a": 4, "string": 5, "": 6}
        predictions = np.array([[6], [6]])

        scores = self._squad_scores(token_to_id, inputs, predictions)

        self.assertDictClose(scores, {"em": 0, "f1": 0})

    def test_big(self):
        """A four-example batch with zero-filled rows yields em=25, f1=35."""
        inputs = [
            {"answers": ["big moose", "hippo"]},
            {"answers": ["correct1"]},
            {"answers": ["correct2.1", "correct2.2"]},
            {"answers": ["a", "b"]},
        ]
        token_to_id = {
            "‘a": 2,
            "big": 3,
            "Moose!‘": 4,
            "wrong": 5,
            "correct2.2": 6,
            "c": 7,
        }
        predictions = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])

        scores = self._squad_scores(token_to_id, inputs, predictions)

        self.assertDictClose(scores, {"em": 25., "f1": 35.}, places=2)

    def test_small(self):
        """A one-word prediction against a two-word reference: f1 is 2/3."""
        inputs = [{"answers": ["abc abd", "$$$$"]}]
        predictions = np.array([[2]])

        scores = self._squad_scores({"abd": 2}, inputs, predictions)

        self.assertDictClose(scores, {"f1": 100 * 2.0 / 3.0, "em": 0.})
| 803 | + |
| 804 | + |
class ShardedSquadTest(test_utils.BaseMetricsTest):
    """Checks the "em"/"f1" values reported by `metrics.ShardedSquad`."""

    def _sharded_scores(self, token_to_id, *shards):
        """Build one `ShardedSquad` per shard, merge them, and compute.

        `MockVocabulary.decode` is replaced by `mock_decode` for the whole
        run, covering vocabulary construction, metric creation, merging and
        the compute call.

        Args:
            token_to_id: word -> id mapping handed to `MockVocabulary`.
            *shards: one or more `(inputs, model_output)` pairs; a metric is
                created for each pair in order and merged left to right.

        Returns:
            Whatever `compute()` returns on the merged metric.
        """
        with mock.patch.object(
            seqio.test_utils.MockVocabulary, "decode", new=mock_decode
        ):
            vocabulary = seqio.test_utils.MockVocabulary(
                token_to_id, vocab_size=10
            )
            features = {"targets": seqio.Feature(vocabulary)}
            merged = None
            for shard_inputs, shard_output in shards:
                shard_metric = metrics.ShardedSquad.from_model_output(
                    shard_inputs, shard_output, features
                )
                if merged is None:
                    merged = shard_metric
                else:
                    merged = merged.merge(shard_metric)
            return merged.compute()

    def test_same(self):
        """Predictions equal to one of the references score 100 on both."""
        ref = "this is a string"
        inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]
        token_to_id = {"this": 2, "is": 3, "a": 4, "string": 5}
        predictions = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])

        scores = self._sharded_scores(token_to_id, (inputs, predictions))

        self.assertDictClose(scores, {"em": 100, "f1": 100})

    def test_different(self):
        """Predictions that decode to the empty string score 0 on both."""
        ref = "this is a string"
        inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]
        token_to_id = {"this": 2, "is": 3, "a": 4, "string": 5, "": 6}
        predictions = np.array([[6], [6]])

        scores = self._sharded_scores(token_to_id, (inputs, predictions))

        self.assertDictClose(scores, {"em": 0, "f1": 0})

    def test_big(self):
        """A four-example batch with zero-filled rows yields em=25, f1=35."""
        inputs = [
            {"answers": ["big moose", "hippo"]},
            {"answers": ["correct1"]},
            {"answers": ["correct2.1", "correct2.2"]},
            {"answers": ["a", "b"]},
        ]
        token_to_id = {
            "‘a": 2,
            "big": 3,
            "Moose!‘": 4,
            "wrong": 5,
            "correct2.2": 6,
            "c": 7,
        }
        predictions = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])

        scores = self._sharded_scores(token_to_id, (inputs, predictions))

        self.assertDictClose(scores, {"em": 25., "f1": 35.}, places=2)

    def test_small(self):
        """A one-word prediction against a two-word reference: f1 is 2/3."""
        inputs = [{"answers": ["abc abd", "$$$$"]}]
        predictions = np.array([[2]])

        scores = self._sharded_scores({"abd": 2}, (inputs, predictions))

        self.assertDictClose(scores, {"f1": 100 * 2.0 / 3.0, "em": 0.})

    def test_batch_update(self):
        """Two merged two-example shards give the same scores as test_big."""
        inputs1 = [
            {"answers": ["big moose", "hippo"]},
            {"answers": ["correct1"]},
        ]
        inputs2 = [
            {"answers": ["correct2.1", "correct2.2"]},
            {"answers": ["a", "b"]},
        ]
        token_to_id = {
            "‘a": 2,
            "big": 3,
            "Moose!‘": 4,
            "wrong": 5,
            "correct2.2": 6,
            "c": 7,
        }
        model_output1 = np.array([[2, 3, 4], [5, 0, 0]])
        model_output2 = np.array([[6], [7]])

        scores = self._sharded_scores(
            token_to_id,
            (inputs1, model_output1),
            (inputs2, model_output2),
        )

        self.assertDictClose(scores, {"em": 25., "f1": 35.}, places=2)
| 918 | + |
| 919 | + |
# Script entry point: call absltest.main() only when this module is run directly.
if __name__ == "__main__":
    absltest.main()
|
0 commit comments