fix bug for demo/multiclass_classification/train.py (#2747)
This commit is contained in:
parent
d570337262
commit
178517524f
@ -47,5 +47,5 @@ bst = xgb.train(param, xg_train, num_round, watchlist)
|
||||
# get prediction, this is in 1D array, need reshape to (ndata, nclass)
|
||||
pred_prob = bst.predict(xg_test).reshape(test_Y.shape[0], 6)
|
||||
pred_label = np.argmax(pred_prob, axis=1)
|
||||
error_rate = np.sum(pred != test_Y) / test_Y.shape[0]
|
||||
error_rate = np.sum(pred_label != test_Y) / test_Y.shape[0]
|
||||
print('Test error using softprob = {}'.format(error_rate))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user