-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathutils.py
78 lines (61 loc) · 2.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import warnings
import numpy as np
from sklearn.exceptions import DataConversionWarning
from sklearn.utils.multiclass import _is_integral_float
def is_multilabel(y):
    """Return whether ``y`` is in multilabel-indicator format.

    ``y`` is considered multilabel when it is two-dimensional with more than
    one column, contains fewer than three distinct values, and its dtype is
    boolean/integer or its distinct values are integral floats.

    Parameters
    ----------
    y : array-like
        Target values. Must expose ``ndim``, ``shape`` and ``dtype``. If the
        object has a ``unique`` method it is used to collect the distinct
        values; otherwise ``np.unique`` is applied.

    Returns
    -------
    bool
        True if ``y`` looks like a multilabel indicator matrix.
    """
    if not (y.ndim == 2 and y.shape[1] > 1):
        return False
    if hasattr(y, "unique"):
        labels = np.asarray(y.unique())
    else:
        labels = np.unique(y)
        # For array libraries that hook NumPy's ``__array_function__``
        # protocol (e.g. dask), ``np.unique`` can return a lazy array that
        # must be materialized. Plain NumPy input returns an ``ndarray``
        # without a ``compute`` method, so only compute when it exists.
        if hasattr(labels, "compute"):
            labels = labels.compute()
    return len(labels) < 3 and (
        y.dtype.kind in "biu" or _is_integral_float(labels)
    )
def type_of_target(y):
    """Determine the kind of target described by ``y``.

    Parameters
    ----------
    y : array-like
        Target values. Must expose ``ndim``, ``shape``, ``dtype`` and
        ``astype``. If the object has a ``unique`` method it is used to
        collect the distinct values; otherwise ``np.unique`` is applied.

    Returns
    -------
    str
        One of ``'multilabel-indicator'``, ``'unknown'``, ``'continuous'``,
        ``'continuous-multioutput'``, ``'multiclass'``,
        ``'multiclass-multioutput'`` or ``'binary'``.
    """
    if is_multilabel(y):
        return "multilabel-indicator"
    if y.ndim > 2:
        return "unknown"
    if y.ndim == 2 and y.shape[1] == 0:
        return "unknown"  # [[]]

    # From here on y has at most two dimensions.
    # [[1, 2], [1, 2]] -> "-multioutput"; [1, 2, 3] or [[1], [2], [3]] -> ""
    is_multioutput = y.ndim == 2 and y.shape[1] > 1
    suffix = "-multioutput" if is_multioutput else ""

    # Float data containing any non-integral value is continuous:
    # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] but not [1., 2., 3.].
    # NOTE: infinite values are not checked.
    if y.dtype.kind == "f" and np.any(y != y.astype(int)):
        return "continuous" + suffix

    if hasattr(y, "unique"):
        labels = np.asarray(y.unique())
    else:
        labels = np.unique(y)
        # np.unique may return a lazy array for libraries implementing
        # ``__array_function__`` (e.g. dask); eager NumPy results have no
        # ``compute`` method, so only materialize when it is available.
        if hasattr(labels, "compute"):
            labels = labels.compute()

    # More than two classes, or several output columns, means multiclass:
    # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]. The column count comes from
    # ``y.shape`` rather than ``len(y[0])`` so no row is indexed; indexing
    # row 0 raises on zero-row input.
    if len(labels) > 2 or is_multioutput:
        return "multiclass" + suffix
    # [1, 2] or [["a"], ["b"]]
    return "binary"
def column_or_1d(y, *, warn=False):
    """Flatten a 1-D array or a single-column 2-D array to one dimension.

    Parameters
    ----------
    y : array-like
        Object exposing ``shape`` and ``ravel``.
    warn : bool, default False
        If True, emit a ``DataConversionWarning`` when ``y`` is a column
        vector of shape ``(n, 1)``.

    Returns
    -------
    The result of ``y.ravel()``.

    Raises
    ------
    ValueError
        If ``y`` is neither 1-D nor 2-D with exactly one column.
    """
    dims = y.shape
    is_column_vector = len(dims) == 2 and dims[1] == 1

    # Anything that is not already flat or a single column is rejected.
    if not (len(dims) == 1 or is_column_vector):
        raise ValueError(
            f"y should be a 1d array. Got an array of shape {dims} instead."
        )

    if warn and is_column_vector:
        # stacklevel=2 attributes the warning to this function's caller.
        warnings.warn(
            "A column-vector y was passed when a 1d array was expected. "
            "Please change the shape of y to (n_samples, ), for example "
            "using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )
    return y.ravel()
def check_classification_targets(y):
    """Ensure ``y`` describes a classification target.

    Parameters
    ----------
    y : array-like
        Target values, categorized with ``type_of_target``.

    Raises
    ------
    ValueError
        If the detected target type is not an accepted classification type.
    """
    # 'multilabel-sequences' is accepted here even though ``type_of_target``
    # in this module never returns it.
    accepted_kinds = (
        "binary",
        "multiclass",
        "multiclass-multioutput",
        "multilabel-indicator",
        "multilabel-sequences",
    )
    target_kind = type_of_target(y)
    if target_kind in accepted_kinds:
        return
    raise ValueError("Unknown label type: %r" % target_kind)