fix bug for demo/multiclass_classification/train.py (#2747)

This commit is contained in:
zhxfl 2017-09-26 11:37:21 +08:00 committed by Vadim Khotilovich
parent d570337262
commit 178517524f

View File

@ -47,5 +47,5 @@ bst = xgb.train(param, xg_train, num_round, watchlist)
# get prediction, this is in 1D array, need reshape to (ndata, nclass) # get prediction, this is in 1D array, need reshape to (ndata, nclass)
pred_prob = bst.predict(xg_test).reshape(test_Y.shape[0], 6) pred_prob = bst.predict(xg_test).reshape(test_Y.shape[0], 6)
pred_label = np.argmax(pred_prob, axis=1) pred_label = np.argmax(pred_prob, axis=1)
error_rate = np.sum(pred != test_Y) / test_Y.shape[0] error_rate = np.sum(pred_label != test_Y) / test_Y.shape[0]
print('Test error using softprob = {}'.format(error_rate)) print('Test error using softprob = {}'.format(error_rate))