Skip to content

Commit d634ffa

Browse files
committed
Move Jupyter Notebooks from Python repo
1 parent f065fc1 commit d634ffa

14 files changed

+9823
-0
lines changed

machine_learning/dbscan/dbscan.ipynb

+376
Large diffs are not rendered by default.

machine_learning/dbscan/dbscan.py

+271
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from sklearn.datasets import make_moons
4+
import warnings
5+
6+
7+
def euclidean_distance(q, p):
8+
"""
9+
Calculates the Euclidean distance
10+
between points q and p
11+
12+
Distance can only be calculated between numeric values
13+
>>> euclidean_distance([1,'a'],[1,2])
14+
Traceback (most recent call last):
15+
...
16+
ValueError: Non-numeric input detected
17+
18+
The dimentions of both the points must be the same
19+
>>> euclidean_distance([1,1,1],[1,2])
20+
Traceback (most recent call last):
21+
...
22+
ValueError: expected dimensions to be 2-d, instead got p:3 and q:2
23+
24+
Supports only two dimentional points
25+
>>> euclidean_distance([1,1,1],[1,2])
26+
Traceback (most recent call last):
27+
...
28+
ValueError: expected dimensions to be 2-d, instead got p:3 and q:2
29+
30+
Input should be in the format [x,y] or (x,y)
31+
>>> euclidean_distance(1,2)
32+
Traceback (most recent call last):
33+
...
34+
TypeError: inputs must be iterable, either list [x,y] or tuple (x,y)
35+
"""
36+
if not hasattr(q, "__iter__") or not hasattr(p, "__iter__"):
37+
raise TypeError("inputs must be iterable, either list [x,y] or tuple (x,y)")
38+
39+
if isinstance(q, str) or isinstance(p, str):
40+
raise TypeError("inputs cannot be str")
41+
42+
if len(q) != 2 or len(p) != 2:
43+
raise ValueError(
44+
"expected dimensions to be 2-d, instead got p:{} and q:{}".format(
45+
len(q), len(p)
46+
)
47+
)
48+
49+
for num in q + p:
50+
try:
51+
num = int(num)
52+
except:
53+
raise ValueError("Non-numeric input detected")
54+
55+
a = pow((q[0] - p[0]), 2)
56+
b = pow((q[1] - p[1]), 2)
57+
return pow((a + b), 0.5)
58+
59+
60+
def find_neighbors(db, q, eps):
61+
"""
62+
Finds all points in the db that
63+
are within a distance of eps from Q
64+
65+
eps value should be a number
66+
>>> find_neighbors({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, (2,5),'a')
67+
Traceback (most recent call last):
68+
...
69+
ValueError: eps should be either int or float
70+
71+
Q must be a 2-d point as list or tuple
72+
>>> find_neighbors({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, 2, 0.5)
73+
Traceback (most recent call last):
74+
...
75+
TypeError: Q must a 2-dimentional point in the format (x,y) or [x,y]
76+
77+
Points must be in correct format
78+
>>> find_neighbors([], (2,2) ,0.4)
79+
Traceback (most recent call last):
80+
...
81+
TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
82+
"""
83+
84+
if not isinstance(eps, (int, float)):
85+
raise ValueError("eps should be either int or float")
86+
87+
if not hasattr(q, "__iter__"):
88+
raise TypeError("Q must a 2-dimentional point in the format (x,y) or [x,y]")
89+
90+
if not isinstance(db, dict):
91+
raise TypeError(
92+
"db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
93+
)
94+
95+
return [p for p in db if euclidean_distance(q, p) <= eps]
96+
97+
98+
def plot_cluster(db, clusters, ax):
99+
"""
100+
Extracts all the points in the db and puts them together
101+
as seperate clusters and finally plots them
102+
103+
db cannot be empty
104+
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
105+
>>> plot_cluster({},[1,2], axes[1] )
106+
Traceback (most recent call last):
107+
...
108+
Exception: db is empty. No points to cluster
109+
110+
clusters cannot be empty
111+
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
112+
>>> plot_cluster({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},[],axes[1] )
113+
Traceback (most recent call last):
114+
...
115+
Exception: nothing to cluster. Empty clusters
116+
117+
clusters cannot be empty
118+
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
119+
>>> plot_cluster({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},[],axes[1] )
120+
Traceback (most recent call last):
121+
...
122+
Exception: nothing to cluster. Empty clusters
123+
124+
ax must be a plotable
125+
>>> plot_cluster({ (1,2):{'label':'1'}, (2,3):{'label':'2'}},[1,2], [] )
126+
Traceback (most recent call last):
127+
...
128+
TypeError: ax must be an slot in a matplotlib figure
129+
"""
130+
if len(db) == 0:
131+
raise Exception("db is empty. No points to cluster")
132+
133+
if len(clusters) == 0:
134+
raise Exception("nothing to cluster. Empty clusters")
135+
136+
if not hasattr(ax, "plot"):
137+
raise TypeError("ax must be an slot in a matplotlib figure")
138+
139+
temp = []
140+
noise = []
141+
for i in clusters:
142+
stack = []
143+
for k, v in db.items():
144+
if v["label"] == i:
145+
stack.append(k)
146+
elif v["label"] == "noise":
147+
noise.append(k)
148+
temp.append(stack)
149+
150+
color = iter(plt.cm.rainbow(np.linspace(0, 1, len(clusters))))
151+
for i in range(0, len(temp)):
152+
c = next(color)
153+
x = [l[0] for l in temp[i]]
154+
y = [l[1] for l in temp[i]]
155+
ax.plot(x, y, "ro", c=c)
156+
157+
x = [l[0] for l in noise]
158+
y = [l[1] for l in noise]
159+
ax.plot(x, y, "ro", c="0")
160+
161+
162+
def dbscan(db, eps, min_pts):
163+
"""
164+
Implementation of the DBSCAN algorithm
165+
166+
Points must be in correct format
167+
>>> dbscan([], (2,2) ,0.4)
168+
Traceback (most recent call last):
169+
...
170+
TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
171+
172+
eps value should be a number
173+
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},'a',20 )
174+
Traceback (most recent call last):
175+
...
176+
ValueError: eps should be either int or float
177+
178+
min_pts value should be an integer
179+
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},0.4,20.0 )
180+
Traceback (most recent call last):
181+
...
182+
ValueError: min_pts should be int
183+
184+
db cannot be empty
185+
>>> dbscan({},0.4,20.0 )
186+
Traceback (most recent call last):
187+
...
188+
Exception: db is empty, nothing to cluster
189+
190+
min_pts cannot be negative
191+
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, 0.4, -20)
192+
Traceback (most recent call last):
193+
...
194+
ValueError: min_pts or eps cannot be negative
195+
196+
eps cannot be negative
197+
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},-0.4, 20)
198+
Traceback (most recent call last):
199+
...
200+
ValueError: min_pts or eps cannot be negative
201+
202+
"""
203+
if not isinstance(db, dict):
204+
raise TypeError(
205+
"db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
206+
)
207+
208+
if len(db) == 0:
209+
raise Exception("db is empty, nothing to cluster")
210+
211+
if not isinstance(eps, (int, float)):
212+
raise ValueError("eps should be either int or float")
213+
214+
if not isinstance(min_pts, int):
215+
raise ValueError("min_pts should be int")
216+
217+
if min_pts < 0 or eps < 0:
218+
raise ValueError("min_pts or eps cannot be negative")
219+
220+
if min_pts == 0:
221+
warnings.warn("min_pts is 0. Are you sure you want this ?")
222+
223+
if eps == 0:
224+
warnings.warn("eps is 0. Are you sure you want this ?")
225+
226+
clusters = []
227+
c = 0
228+
for p in db:
229+
if db[p]["label"] != "undefined":
230+
continue
231+
neighbors = find_neighbors(db, p, eps)
232+
if len(neighbors) < min_pts:
233+
db[p]["label"] = "noise"
234+
continue
235+
c += 1
236+
clusters.append(c)
237+
db[p]["label"] = c
238+
neighbors.remove(p)
239+
seed_set = neighbors.copy()
240+
while seed_set != []:
241+
q = seed_set.pop(0)
242+
if db[q]["label"] == "noise":
243+
db[q]["label"] = c
244+
if db[q]["label"] != "undefined":
245+
continue
246+
db[q]["label"] = c
247+
neighbors_n = find_neighbors(db, q, eps)
248+
if len(neighbors_n) >= min_pts:
249+
seed_set = seed_set + neighbors_n
250+
return db, clusters
251+
252+
253+
if __name__ == "__main__":
254+
255+
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
256+
257+
x, label = make_moons(n_samples=200, noise=0.1, random_state=19)
258+
259+
axes[0].plot(x[:, 0], x[:, 1], "ro")
260+
261+
points = {(point[0], point[1]): {"label": "undefined"} for point in x}
262+
263+
eps = 0.25
264+
265+
min_pts = 12
266+
267+
db, clusters = dbscan(points, eps, min_pts)
268+
269+
plot_cluster(db, clusters, axes[1])
270+
271+
plt.show()

0 commit comments

Comments
 (0)