@@ -34,19 +34,19 @@ def abserr(predicted, target):
34
34
35
35
36
36
# Predict (probability) based on given parameters
37
- def predict_proba (X , Weights ):
37
+ def predict_prob (X , Weights ):
38
38
Z = af .matmul (X , Weights )
39
39
return af .sigmoid (Z )
40
40
41
41
42
42
# Predict (log probability) based on given parameters
43
- def predict_log_proba (X , Weights ):
44
- return af .log (predict_proba (X , Weights ))
43
+ def predict_log_prob (X , Weights ):
44
+ return af .log (predict_prob (X , Weights ))
45
45
46
46
47
47
# Give most likely class based on given parameters
48
- def predict (X , Weights ):
49
- probs = predict_proba (X , Weights )
48
+ def predict_class (X , Weights ):
49
+ probs = predict_prob (X , Weights )
50
50
_ , classes = af .imax (probs , 1 )
51
51
return classes
52
52
@@ -66,7 +66,7 @@ def cost(Weights, X, Y, lambda_param=1.0):
66
66
lambdat [0 , :] = 0
67
67
68
68
# Get the prediction
69
- H = predict_proba (X , Weights )
69
+ H = predict_prob (X , Weights )
70
70
71
71
# Cost of misprediction
72
72
Jerr = - 1 * af .sum (Y * af .log (H ) + (1 - Y ) * af .log (1 - H ), dim = 0 )
@@ -122,7 +122,7 @@ def benchmark_logistic_regression(train_feats, train_targets, test_feats):
122
122
t0 = time .time ()
123
123
iters = 100
124
124
for i in range (iters ):
125
- test_outputs = predict (test_feats , Weights )
125
+ test_outputs = predict_prob (test_feats , Weights )
126
126
af .eval (test_outputs )
127
127
sync ()
128
128
t1 = time .time ()
@@ -172,8 +172,8 @@ def logit_demo(console, perc):
172
172
af .sync ()
173
173
174
174
# Predict the results
175
- train_outputs = predict_proba (train_feats , Weights )
176
- test_outputs = predict_proba (test_feats , Weights )
175
+ train_outputs = predict_prob (train_feats , Weights )
176
+ test_outputs = predict_prob (test_feats , Weights )
177
177
178
178
print ('Accuracy on training data: {0:2.2f}' .format (accuracy (train_outputs , train_targets )))
179
179
print ('Accuracy on testing data: {0:2.2f}' .format (accuracy (test_outputs , test_targets )))
0 commit comments