-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy path_smote_tomek.py
144 lines (115 loc) · 4.8 KB
/
_smote_tomek.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Class to perform over-sampling using SMOTE and cleaning using Tomek
links."""
# Authors: Guillaume Lemaitre <[email protected]>
# Christos Aridas
# License: MIT
from __future__ import division
import logging
from sklearn.base import clone
from sklearn.utils import check_X_y
from ..base import BaseSampler
from ..over_sampling import SMOTE
from ..over_sampling.base import BaseOverSampler
from ..under_sampling import TomekLinks
from ..utils import check_target_type
from ..utils import Substitution
from ..utils._docstring import _random_state_docstring
@Substitution(
    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
    random_state=_random_state_docstring)
class SMOTETomek(BaseSampler):
    """Class to perform over-sampling using SMOTE and cleaning using
    Tomek links.

    Combine over- and under-sampling using SMOTE and Tomek links.

    Read more in the :ref:`User Guide <combine>`.

    Parameters
    ----------
    {sampling_strategy}

    {random_state}

    smote : object, optional (default=SMOTE())
        The :class:`imblearn.over_sampling.SMOTE` object to use. If not
        given, a :class:`imblearn.over_sampling.SMOTE` object is built from
        this sampler's ``sampling_strategy``, ``random_state`` and ``ratio``.
        If given, a clone of it is used as-is: those three parameters of
        ``SMOTETomek`` are then not forwarded to it.

    tomek : object, optional (default=TomekLinks(sampling_strategy='all'))
        The :class:`imblearn.under_sampling.TomekLinks` object to use. If not
        given, a :class:`imblearn.under_sampling.TomekLinks` object with
        ``sampling_strategy='all'`` is used. If given, a clone of it is used.

    ratio : str, dict, or callable
        .. deprecated:: 0.4
           Use the parameter ``sampling_strategy`` instead. It will be removed
           in 0.6.

    Notes
    -----
    The method is presented in [1]_.

    Supports multi-class resampling. Refer to SMOTE and TomekLinks regarding
    the scheme used.

    See :ref:`sphx_glr_auto_examples_combine_plot_smote_tomek.py` and
    :ref:`sphx_glr_auto_examples_combine_plot_comparison_combine.py`.

    See also
    --------
    SMOTEENN : Over-sample using SMOTE followed by under-sampling using Edited
        Nearest Neighbours.

    References
    ----------
    .. [1] G. Batista, B. Bazzan, M. Monard, "Balancing Training Data for
       Automated Annotation of Keywords: a Case Study," In WOB, 10-18, 2003.

    Examples
    --------

    >>> from collections import Counter
    >>> from sklearn.datasets import make_classification
    >>> from imblearn.combine import SMOTETomek # doctest: +NORMALIZE_WHITESPACE
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
    >>> print('Original dataset shape %s' % Counter(y))
    Original dataset shape Counter({{1: 900, 0: 100}})
    >>> smt = SMOTETomek(random_state=42)
    >>> X_res, y_res = smt.fit_resample(X, y)
    >>> print('Resampled dataset shape %s' % Counter(y_res))
    Resampled dataset shape Counter({{0: 900, 1: 900}})

    """
    # NOTE(review): presumably read by BaseSampler to interpret
    # ``sampling_strategy`` as an over-sampling target -- confirm in base.
    _sampling_type = 'over-sampling'

    def __init__(self,
                 sampling_strategy='auto',
                 random_state=None,
                 smote=None,
                 tomek=None,
                 ratio=None):
        super(SMOTETomek, self).__init__()
        # Parameters are stored unmodified; validation and defaulting are
        # deferred to ``_validate_estimator`` at fit time.
        self.sampling_strategy = sampling_strategy
        self.random_state = random_state
        self.smote = smote
        self.tomek = tomek
        self.ratio = ratio
        self.logger = logging.getLogger(__name__)

    def _validate_estimator(self):
        """Validate or build the SMOTE and TomekLinks samplers.

        Sets ``self.smote_`` and ``self.tomek_``. A user-supplied sampler is
        cloned so that fitting never mutates the object passed to
        ``__init__``.

        Raises
        ------
        ValueError
            If ``smote`` is not a ``SMOTE`` instance or ``tomek`` is not a
            ``TomekLinks`` instance.
        """
        if self.smote is None:
            # Only the default SMOTE receives this sampler's strategy,
            # random state and (deprecated) ratio.
            self.smote_ = SMOTE(
                sampling_strategy=self.sampling_strategy,
                random_state=self.random_state,
                ratio=self.ratio)
        elif isinstance(self.smote, SMOTE):
            self.smote_ = clone(self.smote)
        else:
            raise ValueError('smote needs to be a SMOTE object. '
                             'Got {} instead.'.format(type(self.smote)))

        if self.tomek is None:
            # 'all' lets the Tomek step clean links on every class, not only
            # the ones targeted by the over-sampling strategy.
            self.tomek_ = TomekLinks(sampling_strategy='all')
        elif isinstance(self.tomek, TomekLinks):
            self.tomek_ = clone(self.tomek)
        else:
            raise ValueError('tomek needs to be a TomekLinks object. '
                             'Got {} instead.'.format(type(self.tomek)))

    def _fit_resample(self, X, y, sample_weight=None):
        """Over-sample with SMOTE, then clean Tomek links.

        Parameters
        ----------
        X : array-like or sparse matrix (CSR/CSC)
            Input samples.
        y : array-like
            Target values.
        sample_weight : optional
            Forwarded unchanged to the SMOTE step's ``fit_resample``.

        Returns
        -------
        The result of ``self.tomek_.fit_resample`` applied to all arrays
        returned by the SMOTE step (presumably the resampled ``X`` and
        ``y`` -- confirm against BaseSampler).
        """
        self._validate_estimator()
        y = check_target_type(y)
        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
        self.sampling_strategy_ = self.sampling_strategy

        # The SMOTE output is forwarded whole (star-unpacked) so that any
        # extra arrays it returns reach the Tomek step as well.
        resampled_arrays = self.smote_.fit_resample(X, y, sample_weight)
        return self.tomek_.fit_resample(*resampled_arrays)