Commit aa1add3

FEA allow any resampler in the BalancedBaggingClassifier (#808)
1 parent 76abfa4 commit aa1add3

File tree: 8 files changed (+503, -96 lines)

Diff for: doc/bibtex/refs.bib (+29)

@@ -244,3 +244,32 @@ @article{wilson1997improved
 pages={1--34},
 year={1997}
 }
+
+@inproceedings{wang2009diversity,
+title={Diversity analysis on imbalanced data sets by using ensemble models},
+author={Wang, Shuo and Yao, Xin},
+booktitle={2009 IEEE symposium on computational intelligence and data mining},
+pages={324--331},
+year={2009},
+organization={IEEE}
+}
+
+@article{hido2009roughly,
+title={Roughly balanced bagging for imbalanced data},
+author={Hido, Shohei and Kashima, Hisashi and Takahashi, Yutaka},
+journal={Statistical Analysis and Data Mining: The ASA Data Science Journal},
+volume={2},
+number={5-6},
+pages={412--426},
+year={2009},
+publisher={Wiley Online Library}
+}
+
+@article{maclin1997empirical,
+title={An empirical evaluation of bagging and boosting},
+author={Maclin, Richard and Opitz, David},
+journal={AAAI/IAAI},
+volume={1997},
+pages={546--551},
+year={1997}
+}

Diff for: doc/ensemble.rst (+22, -18)

@@ -18,9 +18,9 @@ Bagging classifier
 
 In ensemble classifiers, bagging methods build several estimators on different
 randomly selected subset of data. In scikit-learn, this classifier is named
-``BaggingClassifier``. However, this classifier does not allow to balance each
-subset of data. Therefore, when training on imbalanced data set, this
-classifier will favor the majority classes::
+:class:`~sklearn.ensemble.BaggingClassifier`. However, this classifier does not
+allow balancing each subset of data. Therefore, when training on an imbalanced
+data set, this classifier will favor the majority classes::
 
   >>> from sklearn.datasets import make_classification
   >>> X, y = make_classification(n_samples=10000, n_features=2, n_informative=2,
@@ -41,14 +41,13 @@ classifier will favor the majority classes::
   >>> balanced_accuracy_score(y_test, y_pred) # doctest: +ELLIPSIS
   0.77...
 
-:class:`BalancedBaggingClassifier` allows to resample each subset of data
-before to train each estimator of the ensemble. In short, it combines the
-output of an :class:`EasyEnsemble` sampler with an ensemble of classifiers
-(i.e. ``BaggingClassifier``). Therefore, :class:`BalancedBaggingClassifier`
-takes the same parameters than the scikit-learn
-``BaggingClassifier``. Additionally, there is two additional parameters,
-``sampling_strategy`` and ``replacement`` to control the behaviour of the
-random under-sampler::
+In :class:`BalancedBaggingClassifier`, each bootstrap sample will be further
+resampled to achieve the desired `sampling_strategy`. Therefore,
+:class:`BalancedBaggingClassifier` takes the same parameters as the
+scikit-learn :class:`~sklearn.ensemble.BaggingClassifier`. In addition, the
+sampling is controlled either by the parameter `sampler` or by the two
+parameters `sampling_strategy` and `replacement`, if one wants to use the
+:class:`~imblearn.under_sampling.RandomUnderSampler`::
 
   >>> from imblearn.ensemble import BalancedBaggingClassifier
   >>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
@@ -61,6 +60,12 @@ random under-sampler::
   >>> balanced_accuracy_score(y_test, y_pred) # doctest: +ELLIPSIS
   0.8...
 
+Changing the `sampler` will give rise to different known implementations
+:cite:`maclin1997empirical`, :cite:`hido2009roughly`,
+:cite:`wang2009diversity`. The following example shows these different
+methods in practice:
+:ref:`sphx_glr_auto_examples_ensemble_plot_bagging_classifier.py`
+
 .. _forest:
 
 Forest of randomized trees
@@ -69,8 +74,7 @@ Forest of randomized trees
 :class:`BalancedRandomForestClassifier` is another ensemble method in which
 each tree of the forest will be provided a balanced bootstrap sample
 :cite:`chen2004using`. This class provides all functionality of the
-:class:`~sklearn.ensemble.RandomForestClassifier` and notably the
-`feature_importances_` attributes::
+:class:`~sklearn.ensemble.RandomForestClassifier`::
 
   >>> from imblearn.ensemble import BalancedRandomForestClassifier
   >>> brf = BalancedRandomForestClassifier(n_estimators=100, random_state=0)
@@ -99,11 +103,11 @@ a boosting iteration :cite:`seiffert2009rusboost`::
   >>> balanced_accuracy_score(y_test, y_pred) # doctest: +ELLIPSIS
   0...
 
-A specific method which uses ``AdaBoost`` as learners in the bagging classifier
-is called EasyEnsemble. The :class:`EasyEnsembleClassifier` allows to bag
-AdaBoost learners which are trained on balanced bootstrap samples
-:cite:`liu2008exploratory`. Similarly to the :class:`BalancedBaggingClassifier`
-API, one can construct the ensemble as::
+A specific method which uses :class:`~sklearn.ensemble.AdaBoostClassifier` as
+learners in the bagging classifier is called "EasyEnsemble". The
+:class:`EasyEnsembleClassifier` allows bagging AdaBoost learners which are
+trained on balanced bootstrap samples :cite:`liu2008exploratory`. Similarly to
+the :class:`BalancedBaggingClassifier` API, one can construct the ensemble as::
 
   >>> from imblearn.ensemble import EasyEnsembleClassifier
   >>> eec = EasyEnsembleClassifier(random_state=0)
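
Editor's note (not part of the committed documentation): the paragraph above introduces the new `sampler` parameter next to the historical `sampling_strategy`/`replacement` keywords. Below is a minimal sketch of the two ways to request random under-sampling, assuming only the public imbalanced-learn API and the `sampler` keyword added by this PR.

from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from imblearn.ensemble import BalancedBaggingClassifier
from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(n_samples=5_000, weights=[0.1, 0.9], random_state=0)

# historical keywords: random under-sampling without replacement
bbc_keywords = BalancedBaggingClassifier(
    sampling_strategy="auto", replacement=False, random_state=0
)
# new keyword introduced by this PR: pass the sampler object directly
bbc_sampler = BalancedBaggingClassifier(sampler=RandomUnderSampler(), random_state=0)

for name, clf in [("keywords", bbc_keywords), ("sampler object", bbc_sampler)]:
    scores = cross_validate(clf, X, y, scoring="balanced_accuracy")["test_score"]
    print(f"{name}: {scores.mean():.3f} +/- {scores.std():.3f}")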

Diff for: doc/whats_new/v0.8.rst (+5)

@@ -24,6 +24,11 @@ New features
   only containing categorical features.
   :pr:`802` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- Add the possibility to pass any type of sampler to
+  :class:`imblearn.ensemble.BalancedBaggingClassifier`, unlocking the
+  implementation of methods based on resampled bagging.
+  :pr:`808` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Enhancements
 ............
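
Editor's note (not from the commit): since the changelog stresses that any type of sampler can now be passed, here is a small hedged sketch using a non-random under-sampler, NearMiss, chosen purely for illustration; any object exposing `fit_resample` should plug in the same way.

from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from imblearn.ensemble import BalancedBaggingClassifier
from imblearn.under_sampling import NearMiss

X, y = make_classification(n_samples=5_000, weights=[0.1, 0.9], random_state=0)

# NearMiss is just one example of a sampler that is neither RandomUnderSampler
# nor RandomOverSampler; the new `sampler` argument accepts it unchanged.
clf = BalancedBaggingClassifier(sampler=NearMiss(), random_state=0)
scores = cross_validate(clf, X, y, scoring="balanced_accuracy")["test_score"]
print(f"NearMiss bagging: {scores.mean():.3f} +/- {scores.std():.3f}")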

Diff for: examples/ensemble/plot_bagging_classifier.py (new file, +175)

@@ -0,0 +1,175 @@
"""
=================================
Bagging classifiers using sampler
=================================

In this example, we show how
:class:`~imblearn.ensemble.BalancedBaggingClassifier` can be used to create a
large variety of classifiers by passing different samplers.

We will give several examples that have been published in the past years.
"""

# Authors: Guillaume Lemaitre <[email protected]>
# License: MIT

# %%
print(__doc__)

# %% [markdown]
# Generate an imbalanced dataset
# ------------------------------
#
# For this example, we will create a synthetic dataset using the function
# :func:`~sklearn.datasets.make_classification`. The problem will be a toy
# classification problem with a ratio of 1:9 between the two classes.

# %%
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=10_000,
    n_features=10,
    weights=[0.1, 0.9],
    class_sep=0.5,
    random_state=0,
)

# %%
import pandas as pd

pd.Series(y).value_counts(normalize=True)

# %% [markdown]
# In the following sections, we will show a couple of algorithms that have
# been proposed over the years. We intend to illustrate how one can reuse the
# :class:`~imblearn.ensemble.BalancedBaggingClassifier` by passing different
# samplers.

# %%
from sklearn.model_selection import cross_validate
from sklearn.ensemble import BaggingClassifier

# Vanilla bagging, without any resampling, as a baseline
ebb = BaggingClassifier()
cv_results = cross_validate(ebb, X, y, scoring="balanced_accuracy")

print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")

# %% [markdown]
# Exactly Balanced Bagging and Over-Bagging
# -----------------------------------------
#
# The :class:`~imblearn.ensemble.BalancedBaggingClassifier` can be used in
# conjunction with a :class:`~imblearn.under_sampling.RandomUnderSampler` or a
# :class:`~imblearn.over_sampling.RandomOverSampler`. These methods are
# referred to as Exactly Balanced Bagging and Over-Bagging, respectively, and
# were first proposed in [1]_.

# %%
from imblearn.ensemble import BalancedBaggingClassifier
from imblearn.under_sampling import RandomUnderSampler

# Exactly Balanced Bagging
ebb = BalancedBaggingClassifier(sampler=RandomUnderSampler())
cv_results = cross_validate(ebb, X, y, scoring="balanced_accuracy")

print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")

# %%
from imblearn.over_sampling import RandomOverSampler

# Over-bagging
over_bagging = BalancedBaggingClassifier(sampler=RandomOverSampler())
cv_results = cross_validate(over_bagging, X, y, scoring="balanced_accuracy")

print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")

# %% [markdown]
# SMOTE-Bagging
# -------------
#
# Instead of using a :class:`~imblearn.over_sampling.RandomOverSampler` that
# makes a bootstrap sample, an alternative is to use
# :class:`~imblearn.over_sampling.SMOTE` as an over-sampler. This is known as
# SMOTE-Bagging [2]_.

# %%
from imblearn.over_sampling import SMOTE

# SMOTE-Bagging
smote_bagging = BalancedBaggingClassifier(sampler=SMOTE())
cv_results = cross_validate(smote_bagging, X, y, scoring="balanced_accuracy")

print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")

# %% [markdown]
# Roughly Balanced Bagging
# ------------------------
#
# While using a :class:`~imblearn.under_sampling.RandomUnderSampler` or a
# :class:`~imblearn.over_sampling.RandomOverSampler` will create exactly the
# desired number of samples, it does not follow the statistical spirit of the
# bagging framework. The authors of [3]_ propose to use a negative binomial
# distribution to compute the number of samples of the majority class to be
# selected and then perform a random under-sampling.
#
# Here, we illustrate this method by implementing a function in charge of the
# resampling and using the :class:`~imblearn.FunctionSampler` to integrate it
# within a :class:`~imblearn.pipeline.Pipeline` and
# :class:`~sklearn.model_selection.cross_validate`.

# %%
from collections import Counter
import numpy as np
from imblearn import FunctionSampler


def roughly_balanced_bagging(X, y, replace=False):
    """Implementation of Roughly Balanced Bagging for binary problem."""
    # find the minority and majority classes
    class_counts = Counter(y)
    majority_class = max(class_counts, key=class_counts.get)
    minority_class = min(class_counts, key=class_counts.get)

    # compute the number of samples to draw from the majority class using
    # a negative binomial distribution
    n_minority_class = class_counts[minority_class]
    n_majority_resampled = np.random.negative_binomial(n=n_minority_class, p=0.5)

    # draw randomly with or without replacement
    majority_indices = np.random.choice(
        np.flatnonzero(y == majority_class),
        size=n_majority_resampled,
        replace=replace,
    )
    minority_indices = np.random.choice(
        np.flatnonzero(y == minority_class),
        size=n_minority_class,
        replace=replace,
    )
    indices = np.hstack([majority_indices, minority_indices])

    return X[indices], y[indices]


# Roughly Balanced Bagging
rbb = BalancedBaggingClassifier(
    sampler=FunctionSampler(func=roughly_balanced_bagging, kw_args={"replace": True})
)
cv_results = cross_validate(rbb, X, y, scoring="balanced_accuracy")

print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")


# %% [markdown]
# .. topic:: References:
#
#    .. [1] R. Maclin, and D. Opitz. "An empirical evaluation of bagging and
#           boosting." AAAI/IAAI 1997 (1997): 546-551.
#
#    .. [2] S. Wang, and X. Yao. "Diversity analysis on imbalanced data sets by
#           using ensemble models." 2009 IEEE symposium on computational
#           intelligence and data mining. IEEE, 2009.
#
#    .. [3] S. Hido, H. Kashima, and Y. Takahashi. "Roughly balanced bagging
#           for imbalanced data." Statistical Analysis and Data Mining: The ASA
#           Data Science Journal 2.5-6 (2009): 412-426.
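
Editor's note (not part of the example file above): the Roughly Balanced Bagging cell draws the majority-class size from a negative binomial distribution. With NumPy's parameterization, the mean of negative_binomial(n, p) is n * (1 - p) / p, so p=0.5 gives a mean equal to n: on average each bag is balanced, while the per-bag size still fluctuates. A quick numeric check of that claim:

import numpy as np

rng = np.random.default_rng(0)
n_minority = 1_000

# number of majority samples drawn for 10_000 hypothetical bags
draws = rng.negative_binomial(n=n_minority, p=0.5, size=10_000)
print(f"mean majority size: {draws.mean():.1f}")   # close to 1000
print(f"std of majority size: {draws.std():.1f}")  # non-zero: sizes fluctuate per bag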

Diff for: examples/ensemble/plot_comparison_ensemble_classifier.py (+1, -2)

@@ -3,15 +3,14 @@
 Compare ensemble classifiers using resampling
 =============================================
 
-Ensembling classifiers have shown to improve classification performance compare
+Ensemble classifiers have shown to improve classification performance compared
 to single learner. However, they will be affected by class imbalance. This
 example shows the benefit of balancing the training set before to learn
 learners. We are making the comparison with non-balanced ensemble methods.
 
 We make a comparison using the balanced accuracy and geometric mean which are
 metrics widely used in the literature to evaluate models learned on imbalanced
 set.
-
 """
 
 # Authors: Guillaume Lemaitre <[email protected]>
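
Editor's note (not part of the modified example): the docstring above evaluates models with balanced accuracy and the geometric mean. A minimal sketch of how the geometric mean plugs into scikit-learn's cross-validation, using imblearn.metrics.geometric_mean_score wrapped in make_scorer:

from sklearn.datasets import make_classification
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_validate
from imblearn.ensemble import BalancedBaggingClassifier
from imblearn.metrics import geometric_mean_score

X, y = make_classification(n_samples=5_000, weights=[0.1, 0.9], random_state=0)

# score the same estimator with the two metrics used in the comparison example
scoring = {
    "balanced_accuracy": "balanced_accuracy",
    "geometric_mean": make_scorer(geometric_mean_score),
}
cv_results = cross_validate(BalancedBaggingClassifier(random_state=0), X, y, scoring=scoring)
for name in scoring:
    scores = cv_results[f"test_{name}"]
    print(f"{name}: {scores.mean():.3f} +/- {scores.std():.3f}")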
