Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit956a227

Browse files
Mark Poscablo9prady9
Mark Poscablo
authored andcommitted
Corrected predict method names
1 parentfabe335 commit956a227

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

‎examples/machine_learning/logistic_regression.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,19 @@ def abserr(predicted, target):
3434

3535

3636
# Predict (probability) based on given parameters
37-
defpredict_proba(X,Weights):
37+
defpredict_prob(X,Weights):
3838
Z=af.matmul(X,Weights)
3939
returnaf.sigmoid(Z)
4040

4141

4242
# Predict (log probability) based on given parameters
43-
defpredict_log_proba(X,Weights):
44-
returnaf.log(predict_proba(X,Weights))
43+
defpredict_log_prob(X,Weights):
44+
returnaf.log(predict_prob(X,Weights))
4545

4646

4747
# Give most likely class based on given parameters
48-
defpredict(X,Weights):
49-
probs=predict_proba(X,Weights)
48+
defpredict_class(X,Weights):
49+
probs=predict_prob(X,Weights)
5050
_,classes=af.imax(probs,1)
5151
returnclasses
5252

@@ -66,7 +66,7 @@ def cost(Weights, X, Y, lambda_param=1.0):
6666
lambdat[0, :]=0
6767

6868
# Get the prediction
69-
H=predict_proba(X,Weights)
69+
H=predict_prob(X,Weights)
7070

7171
# Cost of misprediction
7272
Jerr=-1*af.sum(Y*af.log(H)+ (1-Y)*af.log(1-H),dim=0)
@@ -122,7 +122,7 @@ def benchmark_logistic_regression(train_feats, train_targets, test_feats):
122122
t0=time.time()
123123
iters=100
124124
foriinrange(iters):
125-
test_outputs=predict(test_feats,Weights)
125+
test_outputs=predict_prob(test_feats,Weights)
126126
af.eval(test_outputs)
127127
sync()
128128
t1=time.time()
@@ -172,8 +172,8 @@ def logit_demo(console, perc):
172172
af.sync()
173173

174174
# Predict the results
175-
train_outputs=predict_proba(train_feats,Weights)
176-
test_outputs=predict_proba(test_feats,Weights)
175+
train_outputs=predict_prob(train_feats,Weights)
176+
test_outputs=predict_prob(test_feats,Weights)
177177

178178
print('Accuracy on training data: {0:2.2f}'.format(accuracy(train_outputs,train_targets)))
179179
print('Accuracy on testing data: {0:2.2f}'.format(accuracy(test_outputs,test_targets)))

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp