fix bug for demo/multiclass_classification/train.py (#2747)
This commit is contained in:
parent
d570337262
commit
178517524f
@ -47,5 +47,5 @@ bst = xgb.train(param, xg_train, num_round, watchlist)
|
|||||||
# get prediction, this is in 1D array, need reshape to (ndata, nclass)
|
# get prediction, this is in 1D array, need reshape to (ndata, nclass)
|
||||||
pred_prob = bst.predict(xg_test).reshape(test_Y.shape[0], 6)
|
pred_prob = bst.predict(xg_test).reshape(test_Y.shape[0], 6)
|
||||||
pred_label = np.argmax(pred_prob, axis=1)
|
pred_label = np.argmax(pred_prob, axis=1)
|
||||||
error_rate = np.sum(pred != test_Y) / test_Y.shape[0]
|
error_rate = np.sum(pred_label != test_Y) / test_Y.shape[0]
|
||||||
print('Test error using softprob = {}'.format(error_rate))
|
print('Test error using softprob = {}'.format(error_rate))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user