Skip to content

Commit 153f6e0

Browse files
committed
DOC improve make_imbalance example
1 parent 45b538c commit 153f6e0

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

Diff for: examples/datasets/plot_make_imbalance.py

+18-22
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""
2-
===========================
3-
make_imbalance function
4-
===========================
2+
============================
3+
Create an imbalanced dataset
4+
============================
55
6-
An illustration of the make_imbalance function
6+
An illustration of the :func:`imblearn.datasets.make_imbalance` function to
7+
create an imbalanced dataset from a balanced dataset. We show the ability of
8+
:func:`imblearn.datasets.make_imbalance` of dealing with Pandas DataFrame.
79
810
"""
911

@@ -14,36 +16,29 @@
1416

1517
from collections import Counter
1618

19+
import pandas as pd
1720
import matplotlib.pyplot as plt
21+
1822
from sklearn.datasets import make_moons
1923

2024
from imblearn.datasets import make_imbalance
2125

2226
print(__doc__)
2327

24-
25-
def plot_decoration(ax):
26-
ax.spines['top'].set_visible(False)
27-
ax.spines['right'].set_visible(False)
28-
ax.get_xaxis().tick_bottom()
29-
ax.get_yaxis().tick_left()
30-
ax.spines['left'].set_position(('outward', 10))
31-
ax.spines['bottom'].set_position(('outward', 10))
32-
ax.set_xlim([-4, 4])
33-
34-
3528
# Generate the dataset
3629
X, y = make_moons(n_samples=200, shuffle=True, noise=0.5, random_state=10)
30+
X = pd.DataFrame(X, columns=["feature 1", "feature 2"])
3731

3832
# Two subplots, unpack the axes array immediately
3933
f, axs = plt.subplots(2, 3)
4034

4135
axs = [a for ax in axs for a in ax]
4236

43-
axs[0].scatter(X[y == 0, 0], X[y == 0, 1], label="Class #0", alpha=0.5)
44-
axs[0].scatter(X[y == 1, 0], X[y == 1, 1], label="Class #1", alpha=0.5)
37+
X.plot.scatter(
38+
x='feature 1', y='feature 2', c=y, ax=axs[0], colormap='viridis',
39+
colorbar=False
40+
)
4541
axs[0].set_title('Original set')
46-
plot_decoration(axs[0])
4742

4843

4944
def ratio_func(y, multiplier, minority_class):
@@ -58,10 +53,11 @@ def ratio_func(y, multiplier, minority_class):
5853
X_, y_ = make_imbalance(X, y, sampling_strategy=ratio_func,
5954
**{"multiplier": multiplier,
6055
"minority_class": 1})
61-
ax.scatter(X_[y_ == 0, 0], X_[y_ == 0, 1], label="Class #0", alpha=0.5)
62-
ax.scatter(X_[y_ == 1, 0], X_[y_ == 1, 1], label="Class #1", alpha=0.5)
63-
ax.set_title('sampling_strategy = {}'.format(multiplier))
64-
plot_decoration(ax)
56+
X_.plot.scatter(
57+
x='feature 1', y='feature 2', c=y_, ax=ax, colormap='viridis',
58+
colorbar=False
59+
)
60+
ax.set_title('Sampling ratio = {}'.format(multiplier))
6561

6662
plt.tight_layout()
6763
plt.show()

0 commit comments

Comments
 (0)