Skip to content

Commit d105ec8

Browse files
authored
Add ST annotation to evaluators (#2586)
1 parent 99674c7 commit d105ec8

12 files changed

+31
-13
lines changed

sentence_transformers/evaluation/BinaryClassificationEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from . import SentenceEvaluator
34
import logging
@@ -111,7 +112,7 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
111112
scores.append(example.label)
112113
return cls(sentences1, sentences2, scores, **kwargs)
113114

114-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
115+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
115116
if epoch != -1:
116117
if steps == -1:
117118
out_txt = f" after epoch {epoch}"

sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from contextlib import nullcontext
2+
3+
from sentence_transformers import SentenceTransformer
24
from . import SentenceEvaluator, SimilarityFunction
35
import logging
46
import os
@@ -101,7 +103,7 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
101103
scores.append(example.label)
102104
return cls(sentences1, sentences2, scores, **kwargs)
103105

104-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
106+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
105107
if epoch != -1:
106108
if steps == -1:
107109
out_txt = f" after epoch {epoch}"

sentence_transformers/evaluation/InformationRetrievalEvaluator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from . import SentenceEvaluator
34
import torch
@@ -94,7 +95,9 @@ def __init__(
9495
for k in map_at_k:
9596
self.csv_headers.append("{}-MAP@{}".format(score_name, k))
9697

97-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs) -> float:
98+
def __call__(
99+
self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs
100+
) -> float:
98101
if epoch != -1:
99102
if steps == -1:
100103
out_txt = f" after epoch {epoch}"
@@ -147,7 +150,9 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int =
147150
else:
148151
return scores[self.main_score_function]["map@k"][max(self.map_at_k)]
149152

150-
def compute_metrices(self, model, corpus_model=None, corpus_embeddings: Tensor = None) -> Dict[str, float]:
153+
def compute_metrices(
154+
self, model: SentenceTransformer, corpus_model=None, corpus_embeddings: Tensor = None
155+
) -> Dict[str, float]:
151156
if corpus_model is None:
152157
corpus_model = model
153158

sentence_transformers/evaluation/LabelAccuracyEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from . import SentenceEvaluator
23
import torch
34
from torch.utils.data import DataLoader
@@ -37,7 +38,7 @@ def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, w
3738
self.csv_file = "accuracy_evaluation" + name + "_results.csv"
3839
self.csv_headers = ["epoch", "steps", "accuracy"]
3940

40-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
41+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
4142
model.eval()
4243
total = 0
4344
correct = 0

sentence_transformers/evaluation/MSEEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from sentence_transformers.evaluation import SentenceEvaluator
34
import logging
@@ -57,7 +58,7 @@ def __init__(
5758
self.csv_headers = ["epoch", "steps", "MSE"]
5859
self.write_csv = write_csv
5960

60-
def __call__(self, model, output_path, epoch=-1, steps=-1):
61+
def __call__(self, model: SentenceTransformer, output_path, epoch=-1, steps=-1):
6162
if epoch != -1:
6263
if steps == -1:
6364
out_txt = f" after epoch {epoch}"

sentence_transformers/evaluation/MSEEvaluatorFromDataFrame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
all_src_embeddings = teacher_model.encode(all_source_sentences, batch_size=self.batch_size)
7878
self.teacher_embeddings = {sent: emb for sent, emb in zip(all_source_sentences, all_src_embeddings)}
7979

80-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1):
80+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1):
8181
model.eval()
8282

8383
mse_scores = []

sentence_transformers/evaluation/ParaphraseMiningEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from . import SentenceEvaluator
34
import logging
@@ -99,7 +100,7 @@ def __init__(
99100
self.csv_headers = ["epoch", "steps", "precision", "recall", "f1", "threshold", "average_precision"]
100101
self.write_csv = write_csv
101102

102-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
103+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
103104
if epoch != -1:
104105
if steps == -1:
105106
out_txt = f" after epoch {epoch}"

sentence_transformers/evaluation/RerankingEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from . import SentenceEvaluator
34
import logging
@@ -82,7 +83,7 @@ def __init__(
8283
]
8384
self.write_csv = write_csv
8485

85-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
86+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
8687
if epoch != -1:
8788
if steps == -1:
8889
out_txt = f" after epoch {epoch}"

sentence_transformers/evaluation/SentenceEvaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
from sentence_transformers import SentenceTransformer
2+
3+
14
class SentenceEvaluator:
25
"""
36
Base class for all evaluators
47
58
Extend this class and implement __call__ for custom evaluators.
69
"""
710

8-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
11+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
912
"""
1013
This is called during training to evaluate the model.
1114
It returns a score for the evaluation with a higher score indicating a better result.

sentence_transformers/evaluation/SequentialEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from . import SentenceEvaluator
23
from typing import Iterable
34

@@ -14,7 +15,7 @@ def __init__(self, evaluators: Iterable[SentenceEvaluator], main_score_function=
1415
self.evaluators = evaluators
1516
self.main_score_function = main_score_function
1617

17-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
18+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
1819
scores = []
1920
for evaluator in self.evaluators:
2021
scores.append(evaluator(model, output_path, epoch, steps))

sentence_transformers/evaluation/TranslationEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from . import SentenceEvaluator
34
import logging
@@ -70,7 +71,7 @@ def __init__(
7071
self.csv_headers = ["epoch", "steps", "src2trg", "trg2src"]
7172
self.write_csv = write_csv
7273

73-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
74+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
7475
if epoch != -1:
7576
if steps == -1:
7677
out_txt = f" after epoch {epoch}"

sentence_transformers/evaluation/TripletEvaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sentence_transformers import SentenceTransformer
12
from contextlib import nullcontext
23
from . import SentenceEvaluator, SimilarityFunction
34
import logging
@@ -75,7 +76,7 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
7576
negatives.append(example.texts[2])
7677
return cls(anchors, positives, negatives, **kwargs)
7778

78-
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
79+
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
7980
if epoch != -1:
8081
if steps == -1:
8182
out_txt = f" after epoch {epoch}"

0 commit comments

Comments
 (0)