Skip to content

Commit ef3f570

Browse files
Alex7Lirth
authored andcommitted
ENH Add fast kernel classifier/regressor (#13)
1 parent fa8d1fe commit ef3f570

20 files changed

+1522
-2
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ doc/generated/
7171

7272
# PyBuilder
7373
target/
74+
75+
# Pycharm
76+
.idea
77+
venv/

benchmarks/__init__.py

Whitespace-only changes.

benchmarks/_bench/__init__.py

Whitespace-only changes.
+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import matplotlib
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
from time import time
5+
6+
from sklearn_extra.kernel_methods import EigenProClassifier
7+
from sklearn.svm import SVC
8+
from sklearn.datasets import fetch_openml
9+
10+
rng = np.random.RandomState(1)
11+
12+
# Generate sample data from mnist
13+
mnist = fetch_openml("mnist_784")
14+
mnist.data = mnist.data / 255.0
15+
print("Data has loaded")
16+
17+
p = rng.permutation(60000)
18+
x_train = mnist.data[p]
19+
y_train = np.int32(mnist.target[p])
20+
x_test = mnist.data[60000:]
21+
y_test = np.int32(mnist.target[60000:])
22+
23+
# Run tests comparing eig to svc
24+
eig_fit_times = []
25+
eig_pred_times = []
26+
eig_err = []
27+
svc_fit_times = []
28+
svc_pred_times = []
29+
svc_err = []
30+
31+
train_sizes = [500, 1000, 2000, 5000, 10000, 20000, 40000, 60000]
32+
33+
bandwidth = 5.0
34+
35+
# Fit models to data
36+
for train_size in train_sizes:
37+
for name, estimator in [
38+
(
39+
"EigenPro",
40+
EigenProClassifier(
41+
n_epoch=2, bandwidth=bandwidth, random_state=rng
42+
),
43+
),
44+
(
45+
"SupportVector",
46+
SVC(
47+
C=5, gamma=1.0 / (2 * bandwidth * bandwidth), random_state=rng
48+
),
49+
),
50+
]:
51+
stime = time()
52+
estimator.fit(x_train[:train_size], y_train[:train_size])
53+
fit_t = time() - stime
54+
55+
stime = time()
56+
y_pred_test = estimator.predict(x_test)
57+
pred_t = time() - stime
58+
59+
err = 100.0 * np.sum(y_pred_test != y_test) / len(y_test)
60+
if name == "EigenPro":
61+
eig_fit_times.append(fit_t)
62+
eig_pred_times.append(pred_t)
63+
eig_err.append(err)
64+
else:
65+
svc_fit_times.append(fit_t)
66+
svc_pred_times.append(pred_t)
67+
svc_err.append(err)
68+
print(
69+
"%s Classification with %i training samples in %0.2f seconds."
70+
"Test error %.4f" % (name, train_size, fit_t + pred_t, err)
71+
)
72+
73+
# set up grid for figures
74+
fig = plt.figure(num=None, figsize=(6, 4), dpi=160)
75+
ax = plt.subplot2grid((2, 2), (0, 0), rowspan=2)
76+
train_size_labels = ["500", "1k", "2k", "5k", "10k", "20k", "40k", "60k"]
77+
78+
# Graph fit(train) time
79+
ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
80+
ax.plot(train_sizes, svc_fit_times, "o--", color="g", label="SVC")
81+
ax.plot(train_sizes, eig_fit_times, "o-", color="r", label="EigenPro")
82+
ax.set_xscale("log")
83+
ax.set_yscale("log", nonposy="clip")
84+
ax.set_xlabel("train size")
85+
ax.set_ylabel("time (seconds)")
86+
ax.legend()
87+
ax.set_title("Train set")
88+
ax.set_xticks(train_sizes)
89+
ax.set_xticks([], minor=True)
90+
ax.set_xticklabels(train_size_labels)
91+
92+
# Graph prediction(test) time
93+
ax = plt.subplot2grid((2, 2), (0, 1), rowspan=1)
94+
ax.plot(train_sizes, eig_pred_times, "o-", color="r")
95+
ax.plot(train_sizes, svc_pred_times, "o--", color="g")
96+
ax.set_xscale("log")
97+
ax.set_yscale("log", nonposy="clip")
98+
ax.set_ylabel("time (seconds)")
99+
ax.set_title("Test set")
100+
ax.set_xticks(train_sizes)
101+
ax.set_xticks([], minor=True)
102+
ax.set_xticklabels(train_size_labels)
103+
104+
# Graph training error
105+
ax = plt.subplot2grid((2, 2), (1, 1), rowspan=1)
106+
ax.plot(train_sizes, eig_err, "o-", color="r")
107+
ax.plot(train_sizes, svc_err, "o-", color="g")
108+
ax.set_xscale("log")
109+
ax.set_xticks(train_sizes)
110+
ax.set_xticklabels(train_size_labels)
111+
ax.set_xticks([], minor=True)
112+
ax.set_xlabel("train size")
113+
ax.set_ylabel("classification error %")
114+
plt.tight_layout()
115+
plt.show()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import matplotlib
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
from time import time
5+
6+
from sklearn.datasets import fetch_openml
7+
from sklearn_extra.kernel_methods import EigenProClassifier
8+
from sklearn.svm import SVC
9+
10+
rng = np.random.RandomState(1)
11+
12+
# Generate sample data from mnist
13+
mnist = fetch_openml("mnist_784")
14+
mnist.data = mnist.data / 255.0
15+
16+
p = rng.permutation(60000)
17+
x_train = mnist.data[p][:60000]
18+
y_train = np.int32(mnist.target[p][:60000])
19+
x_test = mnist.data[60000:]
20+
y_test = np.int32(mnist.target[60000:])
21+
22+
# randomize 20% of labels
23+
p = rng.choice(len(y_train), np.int32(len(y_train) * 0.2), False)
24+
y_train[p] = rng.choice(10, np.int32(len(y_train) * 0.2))
25+
p = rng.choice(len(y_test), np.int32(len(y_test) * 0.2), False)
26+
y_test[p] = rng.choice(10, np.int32(len(y_test) * 0.2))
27+
28+
# Run tests comparing fkc to svc
29+
eig_fit_times = []
30+
eig_pred_times = []
31+
eig_err = []
32+
svc_fit_times = []
33+
svc_pred_times = []
34+
svc_err = []
35+
36+
train_sizes = [500, 1000, 2000, 5000, 10000, 20000, 40000, 60000]
37+
38+
bandwidth = 5.0
39+
# Fit models to data
40+
for train_size in train_sizes:
41+
for name, estimator in [
42+
(
43+
"EigenPro",
44+
EigenProClassifier(
45+
n_epoch=2, bandwidth=bandwidth, random_state=rng
46+
),
47+
),
48+
("SupportVector", SVC(C=5, gamma=1.0 / (2 * bandwidth * bandwidth))),
49+
]:
50+
stime = time()
51+
estimator.fit(x_train[:train_size], y_train[:train_size])
52+
fit_t = time() - stime
53+
54+
stime = time()
55+
y_pred_test = estimator.predict(x_test)
56+
pred_t = time() - stime
57+
err = 100.0 * np.sum(y_pred_test != y_test) / len(y_test)
58+
if name == "EigenPro":
59+
eig_fit_times.append(fit_t)
60+
eig_pred_times.append(pred_t)
61+
eig_err.append(err)
62+
else:
63+
svc_fit_times.append(fit_t)
64+
svc_pred_times.append(pred_t)
65+
svc_err.append(err)
66+
print(
67+
"%s Classification with %i training samples in %0.2f seconds. "
68+
"Test error %.4f" % (name, train_size, fit_t + pred_t, err)
69+
)
70+
71+
# set up grid for figures
72+
fig = plt.figure(num=None, figsize=(6, 4), dpi=160)
73+
ax = plt.subplot2grid((2, 2), (0, 0), rowspan=2)
74+
train_size_labels = ["500", "1k", "2k", "5k", "10k", "20k", "40k", "60k"]
75+
76+
# Graph fit(train) time
77+
ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
78+
ax.plot(train_sizes, svc_fit_times, "o--", color="g", label="SVC")
79+
ax.plot(train_sizes, eig_fit_times, "o-", color="r", label="EigenPro")
80+
ax.set_xscale("log")
81+
ax.set_yscale("log", nonposy="clip")
82+
ax.set_xlabel("train size")
83+
ax.set_ylabel("time (seconds)")
84+
ax.legend()
85+
ax.set_title("Train set")
86+
ax.set_xticks(train_sizes)
87+
ax.set_xticks([], minor=True)
88+
ax.set_xticklabels(train_size_labels)
89+
90+
# Graph prediction(test) time
91+
ax = plt.subplot2grid((2, 2), (0, 1), rowspan=1)
92+
ax.plot(train_sizes, eig_pred_times, "o-", color="r")
93+
ax.plot(train_sizes, svc_pred_times, "o--", color="g")
94+
ax.set_xscale("log")
95+
ax.set_yscale("log", nonposy="clip")
96+
ax.set_ylabel("time (seconds)")
97+
ax.set_title("Test set")
98+
ax.set_xticks(train_sizes)
99+
ax.set_xticks([], minor=True)
100+
ax.set_xticklabels(train_size_labels)
101+
102+
# Graph training error
103+
ax = plt.subplot2grid((2, 2), (1, 1), rowspan=1)
104+
ax.plot(train_sizes, eig_err, "o-", color="r")
105+
ax.plot(train_sizes, svc_err, "o-", color="g")
106+
ax.set_xscale("log")
107+
ax.set_xticks(train_sizes)
108+
ax.set_xticklabels(train_size_labels)
109+
ax.set_xticks([], minor=True)
110+
ax.set_xlabel("train size")
111+
ax.set_ylabel("classification error %")
112+
plt.tight_layout()
113+
plt.show()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import matplotlib
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
from time import time
5+
6+
from sklearn.datasets import make_classification
7+
from sklearn_extra.kernel_methods import EigenProClassifier
8+
from sklearn.svm import SVC
9+
10+
rng = np.random.RandomState(1)
11+
12+
max_size = 50000
13+
test_size = 10000
14+
15+
# Get data for testing
16+
17+
x, y = make_classification(
18+
n_samples=max_size + test_size,
19+
n_features=400,
20+
n_informative=6,
21+
random_state=rng,
22+
)
23+
24+
x_train = x[:max_size]
25+
y_train = y[:max_size]
26+
x_test = x[max_size:]
27+
y_test = y[max_size:]
28+
29+
eig_fit_times = []
30+
eig_pred_times = []
31+
eig_err = []
32+
svc_fit_times = []
33+
svc_pred_times = []
34+
svc_err = []
35+
36+
train_sizes = [2000, 5000, 10000, 20000, 50000]
37+
38+
bandwidth = 10.0
39+
for train_size in train_sizes:
40+
for name, estimator in [
41+
(
42+
"EigenPro",
43+
EigenProClassifier(
44+
n_epoch=3,
45+
bandwidth=bandwidth,
46+
n_components=30,
47+
subsample_size=1000,
48+
random_state=rng,
49+
),
50+
),
51+
("SupportVector", SVC(C=5, gamma=1.0 / (2 * bandwidth * bandwidth))),
52+
]:
53+
stime = time()
54+
estimator.fit(x_train[:train_size], y_train[:train_size])
55+
fit_t = time() - stime
56+
57+
stime = time()
58+
y_pred_test = estimator.predict(x_test)
59+
pred_t = time() - stime
60+
61+
err = 100.0 * np.sum(y_pred_test != y_test) / len(y_test)
62+
if name == "EigenPro":
63+
eig_fit_times.append(fit_t)
64+
eig_pred_times.append(pred_t)
65+
eig_err.append(err)
66+
else:
67+
svc_fit_times.append(fit_t)
68+
svc_pred_times.append(pred_t)
69+
svc_err.append(err)
70+
print(
71+
"%s Classification with %i training samples in %0.2f seconds."
72+
% (name, train_size, fit_t + pred_t)
73+
)
74+
75+
# set up grid for figures
76+
fig = plt.figure(num=None, figsize=(6, 4), dpi=160)
77+
ax = plt.subplot2grid((2, 2), (0, 0), rowspan=2)
78+
train_size_labels = [str(s) for s in train_sizes]
79+
80+
# Graph fit(train) time
81+
ax.plot(train_sizes, svc_fit_times, "o--", color="g", label="SVC")
82+
ax.plot(train_sizes, eig_fit_times, "o-", color="r", label="FKC (EigenPro)")
83+
ax.set_xscale("log")
84+
ax.set_yscale("log", nonposy="clip")
85+
ax.set_xlabel("train size")
86+
ax.set_ylabel("time (seconds)")
87+
88+
ax.legend()
89+
ax.set_title("Train set")
90+
ax.set_xticks(train_sizes)
91+
ax.set_xticklabels(train_size_labels)
92+
ax.set_xticks([], minor=True)
93+
ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
94+
95+
# Graph prediction(test) time
96+
ax = plt.subplot2grid((2, 2), (0, 1), rowspan=1)
97+
ax.plot(train_sizes, eig_pred_times, "o-", color="r")
98+
ax.plot(train_sizes, svc_pred_times, "o--", color="g")
99+
ax.set_xscale("log")
100+
ax.set_yscale("log", nonposy="clip")
101+
ax.set_ylabel("time (seconds)")
102+
ax.set_title("Test set")
103+
ax.set_xticks([])
104+
ax.set_xticks([], minor=True)
105+
106+
# Graph training error
107+
ax = plt.subplot2grid((2, 2), (1, 1), rowspan=1)
108+
ax.plot(train_sizes, eig_err, "o-", color="r")
109+
ax.plot(train_sizes, svc_err, "o-", color="g")
110+
ax.set_xscale("log")
111+
ax.set_xticks(train_sizes)
112+
ax.set_xticklabels(train_size_labels)
113+
ax.set_xticks([], minor=True)
114+
ax.set_xlabel("train size")
115+
ax.set_ylabel("classification error %")
116+
plt.tight_layout()
117+
plt.show()

doc/api.rst

+17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@ Kernel approximation
1313

1414
kernel_approximation.Fastfood
1515

16+
EigenPro
17+
========
18+
19+
.. currentmodule:: doc
20+
21+
.. toctree::
22+
modules/eigenpro
23+
24+
.. currentmodule:: sklearn_extra
25+
26+
.. autosummary::
27+
:toctree: generated/
28+
:template: class.rst
29+
30+
kernel_methods.EigenProRegressor
31+
kernel_methods.EigenProClassifier
32+
1633
Clustering
1734
====================
1835

doc/images/eigenpro_mnist.png

194 KB
Loading

doc/images/eigenpro_mnist_noisy.png

205 KB
Loading

doc/images/eigenpro_synthetic.png

191 KB
Loading

0 commit comments

Comments
 (0)