
Commit dfea6f3

janithwannicclauss authored and committed
✅ added tests for Perceptron in Neural Networks (#1506)
* ✅ added tests for Perceptron in Neural Networks
* Space
* Format code with psf/black
1 parent 1ed47ad commit dfea6f3
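
The commit adds doctests to training(), sort() and sign() and wires doctest.testmod() into the module's __main__ block. A minimal sketch of running those doctests from Python, assuming the repository root is on sys.path and the module imports as neural_network.perceptron (both assumptions, not stated in the commit):

import doctest

import neural_network.perceptron as perceptron  # assumed import path

# Collect and run the >>> examples added to training(), sort() and sign().
# The training/sort examples print epoch information, so expect console output.
results = doctest.testmod(perceptron, verbose=False)
print(results)  # TestResults(failed=0, attempted=...) when everything passes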

File tree

1 file changed
+88 -25 lines changed


neural_network/perceptron.py

+88 -25
@@ -1,29 +1,53 @@
 """
-
 Perceptron
 w = w + N * (d(k) - y) * x(k)
 
-Using perceptron network for oil analysis,
-with Measuring of 3 parameters that represent chemical characteristics we can classify the oil, in p1 or p2
+Using perceptron network for oil analysis, with Measuring of 3 parameters
+that represent chemical characteristics we can classify the oil, in p1 or p2
 p1 = -1
 p2 = 1
-
 """
 import random
 
 
 class Perceptron:
-    def __init__(self, sample, exit, learn_rate=0.01, epoch_number=1000, bias=-1):
+    def __init__(self, sample, target, learning_rate=0.01, epoch_number=1000, bias=-1):
+        """
+        Initializes a Perceptron network for oil analysis
+        :param sample: sample dataset of 3 parameters with shape [30,3]
+        :param target: variable for classification with two possible states -1 or 1
+        :param learning_rate: learning rate used in optimizing.
+        :param epoch_number: number of epochs to train network on.
+        :param bias: bias value for the network.
+        """
         self.sample = sample
-        self.exit = exit
-        self.learn_rate = learn_rate
+        if len(self.sample) == 0:
+            raise AttributeError("Sample data can not be empty")
+        self.target = target
+        if len(self.target) == 0:
+            raise AttributeError("Target data can not be empty")
+        if len(self.sample) != len(self.target):
+            raise AttributeError(
+                "Sample data and Target data do not have matching lengths"
+            )
+        self.learning_rate = learning_rate
         self.epoch_number = epoch_number
         self.bias = bias
         self.number_sample = len(sample)
-        self.col_sample = len(sample[0])
+        self.col_sample = len(sample[0])  # number of columns in dataset
         self.weight = []
 
-    def training(self):
+    def training(self) -> None:
+        """
+        Trains perceptron for epochs <= given number of epochs
+        :return: None
+        >>> data = [[2.0149, 0.6192, 10.9263]]
+        >>> targets = [-1]
+        >>> perceptron = Perceptron(data,targets)
+        >>> perceptron.training() # doctest: +ELLIPSIS
+        ('\\nEpoch:\\n', ...)
+        ...
+        """
         for sample in self.sample:
             sample.insert(0, self.bias)
 
@@ -35,31 +59,47 @@ def training(self):
         epoch_count = 0
 
         while True:
-            erro = False
+            has_misclassified = False
             for i in range(self.number_sample):
                 u = 0
                 for j in range(self.col_sample + 1):
                     u = u + self.weight[j] * self.sample[i][j]
                 y = self.sign(u)
-                if y != self.exit[i]:
-
+                if y != self.target[i]:
                     for j in range(self.col_sample + 1):
-
                         self.weight[j] = (
                             self.weight[j]
-                            + self.learn_rate * (self.exit[i] - y) * self.sample[i][j]
+                            + self.learning_rate
+                            * (self.target[i] - y)
+                            * self.sample[i][j]
                         )
-                        erro = True
+                        has_misclassified = True
             # print('Epoch: \n',epoch_count)
             epoch_count = epoch_count + 1
             # if you want controle the epoch or just by erro
-            if erro == False:
+            if not has_misclassified:
                 print(("\nEpoch:\n", epoch_count))
                 print("------------------------\n")
                 # if epoch_count > self.epoch_number or not erro:
                 break
 
-    def sort(self, sample):
+    def sort(self, sample) -> None:
+        """
+        :param sample: example row to classify as P1 or P2
+        :return: None
+        >>> data = [[2.0149, 0.6192, 10.9263]]
+        >>> targets = [-1]
+        >>> perceptron = Perceptron(data,targets)
+        >>> perceptron.training() # doctest:+ELLIPSIS
+        ('\\nEpoch:\\n', ...)
+        ...
+        >>> perceptron.sort([-0.6508, 0.1097, 4.0009]) # doctest: +ELLIPSIS
+        ('Sample: ', ...)
+        classification: P1
+
+        """
+        if len(self.sample) == 0:
+            raise AttributeError("Sample data can not be empty")
         sample.insert(0, self.bias)
         u = 0
         for i in range(self.col_sample + 1):
@@ -74,7 +114,21 @@ def sort(self, sample):
             print(("Sample: ", sample))
             print("classification: P2")
 
-    def sign(self, u):
+    def sign(self, u: float) -> int:
+        """
+        threshold function for classification
+        :param u: input number
+        :return: 1 if the input is greater than 0, otherwise -1
+        >>> data = [[0],[-0.5],[0.5]]
+        >>> targets = [1,-1,1]
+        >>> perceptron = Perceptron(data,targets)
+        >>> perceptron.sign(0)
+        1
+        >>> perceptron.sign(-0.5)
+        -1
+        >>> perceptron.sign(0.5)
+        1
+        """
         return 1 if u >= 0 else -1
 
 
@@ -144,15 +198,24 @@ def sign(self, u):
     1,
 ]
 
-network = Perceptron(
-    sample=samples, exit=exit, learn_rate=0.01, epoch_number=1000, bias=-1
-)
-
-network.training()
 
 if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
+
+    network = Perceptron(
+        sample=samples, target=exit, learning_rate=0.01, epoch_number=1000, bias=-1
+    )
+    network.training()
+    print("Finished training perceptron")
+    print("Enter values to predict or q to exit")
     while True:
         sample = []
-        for i in range(3):
-            sample.insert(i, float(input("value: ")))
+        for i in range(len(samples[0])):
+            observation = input("value: ").strip()
+            if observation == "q":
+                break
+            observation = float(observation)
+            sample.insert(i, observation)
         network.sort(sample)
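
For context, a minimal usage sketch of the class after this change, mirroring the new doctests; the import path and the one-row dataset are illustrative assumptions, not part of the commit:

from neural_network.perceptron import Perceptron  # assumed import path

data = [[2.0149, 0.6192, 10.9263]]  # one observation with 3 chemical measurements
targets = [-1]                      # its class label: p1 = -1, p2 = 1

perceptron = Perceptron(data, targets, learning_rate=0.01, epoch_number=1000, bias=-1)
perceptron.training()  # loops until no sample is misclassified, then prints the epoch count
perceptron.sort([-0.6508, 0.1097, 4.0009])  # per the new doctest, prints "classification: P1"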
