
Commit baa6560

Remove scikit-learn by implementing precision_recall_fscore_support (#1557)

1 parent 0552dbe

File tree

2 files changed (+54 -2)


setup.py (-1)

@@ -148,7 +148,6 @@ def _parse_requirements_file(file_path):
     "transformers<4.37",
     "datasets<2.16",
     "accelerate<0.26",
-    "scikit-learn",
     "seqeval",
 ]
 _sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps
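
Dropping the scikit-learn requirement is safe only if the in-repo replacement in metrics.py (below) reproduces scikit-learn's per-class arrays. A minimal parity-check sketch, assuming scikit-learn is still available in the test environment (it is no longer a dependency of the package) and that the new function is importable from deepsparse.transformers.metrics:

import numpy
from sklearn.metrics import precision_recall_fscore_support as sk_prfs

from deepsparse.transformers.metrics import precision_recall_fscore_support

true_labels = [0, 1, 2, 2, 1, 0, 2]
predicted_labels = [0, 2, 2, 2, 1, 0, 1]

# sklearn's default average=None returns per-class arrays, matching the
# new implementation; zero_division=0 mirrors its `else 0` branches.
for ours, theirs in zip(
    precision_recall_fscore_support(true_labels, predicted_labels),
    sk_prfs(true_labels, predicted_labels, beta=1.0, zero_division=0),
):
    assert numpy.allclose(ours, theirs)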

src/deepsparse/transformers/metrics.py (+54 -1)

@@ -21,7 +21,6 @@
 import numpy
 
 from scipy.special import log_softmax
-from sklearn.metrics import precision_recall_fscore_support
 
 
 __all__ = [
@@ -30,6 +29,60 @@
 ]
 
 
+def precision_recall_fscore_support(true_labels, predicted_labels, beta=1.0):
+    """
+    Calculate precision, recall, and F-beta score for each class.
+
+    Parameters:
+        true_labels (array-like): True labels of the data.
+        predicted_labels (array-like): Predicted labels by the classifier.
+        beta (float): The strength of recall versus precision in the F-score.
+
+    Returns:
+        precision (numpy.ndarray): Precision for each class.
+        recall (numpy.ndarray): Recall for each class.
+        fscore (numpy.ndarray): F-beta score for each class.
+        support (numpy.ndarray): Number of occurrences of each class in true_labels.
+    """
+    true_labels = numpy.array(true_labels)
+    predicted_labels = numpy.array(predicted_labels)
+
+    unique_labels = numpy.unique(numpy.concatenate([true_labels, predicted_labels]))
+    precision = numpy.zeros(len(unique_labels))
+    recall = numpy.zeros(len(unique_labels))
+    fscore = numpy.zeros(len(unique_labels))
+    support = numpy.zeros(len(unique_labels))
+
+    for i, label in enumerate(unique_labels):
+        true_positive = numpy.sum((predicted_labels == label) & (true_labels == label))
+        false_positive = numpy.sum((predicted_labels == label) & (true_labels != label))
+        false_negative = numpy.sum((predicted_labels != label) & (true_labels == label))
+
+        precision[i] = (
+            true_positive / (true_positive + false_positive)
+            if true_positive + false_positive > 0
+            else 0
+        )
+        recall[i] = (
+            true_positive / (true_positive + false_negative)
+            if true_positive + false_negative > 0
+            else 0
+        )
+        fscore[i] = (
+            (
+                (1 + beta**2)
+                * precision[i]
+                * recall[i]
+                / (beta**2 * precision[i] + recall[i])
+            )
+            if precision[i] + recall[i] > 0
+            else 0
+        )
+        support[i] = numpy.sum(true_labels == label)
+
+    return precision, recall, fscore, support
+
+
 class Perplexity:
     def __init__(self, accumulate: bool = False):
         """
