Skip to content

Commit f2f0ea5

Browse files
author
Mark Poscablo
committed
Corrected predict method names
1 parent f970d7c commit f2f0ea5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

Diff for: examples/machine_learning/logistic_regression.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,19 @@ def abserr(predicted, target):
3434

3535

3636
# Predict (probability) based on given parameters
37-
def predict_proba(X, Weights):
37+
def predict_prob(X, Weights):
3838
Z = af.matmul(X, Weights)
3939
return af.sigmoid(Z)
4040

4141

4242
# Predict (log probability) based on given parameters
43-
def predict_log_proba(X, Weights):
44-
return af.log(predict_proba(X, Weights))
43+
def predict_log_prob(X, Weights):
44+
return af.log(predict_prob(X, Weights))
4545

4646

4747
# Give most likely class based on given parameters
48-
def predict(X, Weights):
49-
probs = predict_proba(X, Weights)
48+
def predict_class(X, Weights):
49+
probs = predict_prob(X, Weights)
5050
_, classes = af.imax(probs, 1)
5151
return classes
5252

@@ -66,7 +66,7 @@ def cost(Weights, X, Y, lambda_param=1.0):
6666
lambdat[0, :] = 0
6767

6868
# Get the prediction
69-
H = predict_proba(X, Weights)
69+
H = predict_prob(X, Weights)
7070

7171
# Cost of misprediction
7272
Jerr = -1 * af.sum(Y * af.log(H) + (1 - Y) * af.log(1 - H), dim=0)
@@ -122,7 +122,7 @@ def benchmark_logistic_regression(train_feats, train_targets, test_feats):
122122
t0 = time.time()
123123
iters = 100
124124
for i in range(iters):
125-
test_outputs = predict(test_feats, Weights)
125+
test_outputs = predict_prob(test_feats, Weights)
126126
af.eval(test_outputs)
127127
sync()
128128
t1 = time.time()
@@ -172,8 +172,8 @@ def logit_demo(console, perc):
172172
af.sync()
173173

174174
# Predict the results
175-
train_outputs = predict_proba(train_feats, Weights)
176-
test_outputs = predict_proba(test_feats, Weights)
175+
train_outputs = predict_prob(train_feats, Weights)
176+
test_outputs = predict_prob(test_feats, Weights)
177177

178178
print('Accuracy on training data: {0:2.2f}'.format(accuracy(train_outputs, train_targets)))
179179
print('Accuracy on testing data: {0:2.2f}'.format(accuracy(test_outputs, test_targets)))

0 commit comments

Comments
 (0)