[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify DropTrees calling logic. * Add `training` parameter for prediction method. * [Breaking]: Add `training` to C API. * Changes for R and Python custom objectives. * Correct comment. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/gbm/gbtree.h"
|
||||
@@ -18,7 +20,7 @@ TEST(GBTree, SelectTreeMethod) {
|
||||
mparam.num_output_group = 1;
|
||||
|
||||
std::vector<std::shared_ptr<DMatrix> > caches;
|
||||
std::unique_ptr<GradientBooster> p_gbm{
|
||||
std::unique_ptr<GradientBooster> p_gbm {
|
||||
GradientBooster::Create("gbtree", &generic_param, &mparam, caches)};
|
||||
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
|
||||
|
||||
@@ -175,4 +177,41 @@ TEST(Dart, Json_IO) {
|
||||
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
|
||||
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
|
||||
}
|
||||
|
||||
// Regression test for #5115: DART must drop trees only for predictions made
// during training; inference predictions keep all trees, so the two outputs
// must differ for every row.
// NOTE(review): the interleaved '|' / '||||' lines are diff-viewer rendering
// artifacts of this commit page and are left byte-identical.
TEST(Dart, Prediction) {
|
||||
size_t constexpr kRows = 16, kCols = 10;
|
||||
|
||||
// CreateDMatrix is a test helper returning a heap-allocated handle that we
// must delete at the end of the test (see `delete pp_dmat` below).
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
|
||||
auto& p_mat = *pp_dmat;
|
||||
|
||||
// Alternating binary labels (0, 1, 0, 1, ...).
std::vector<bst_float> labels (kRows);
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
labels[i] = i % 2;
|
||||
}
|
||||
p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
|
||||
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
|
||||
learner->SetParam("booster", "dart");
|
||||
// High drop rate so the effect of tree dropping is clearly visible in the
// training-mode predictions.
learner->SetParam("rate_drop", "0.5");
|
||||
learner->Configure();
|
||||
|
||||
// Train enough rounds that the ensemble has multiple trees to drop.
for (size_t i = 0; i < 16; ++i) {
|
||||
learner->UpdateOneIter(i, p_mat.get());
|
||||
}
|
||||
|
||||
HostDeviceVector<float> predts_training;
|
||||
// Last argument is the new `training` flag (added by this commit):
// true -> prediction is part of training, DART may drop trees.
learner->Predict(p_mat.get(), false, &predts_training, 0, true);
|
||||
HostDeviceVector<float> predts_inference;
|
||||
// training=false -> inference path, the full ensemble is used.
learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
|
||||
|
||||
auto& h_predts_training = predts_training.ConstHostVector();
|
||||
auto& h_predts_inference = predts_inference.ConstHostVector();
|
||||
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
|
||||
for (size_t i = 0; i < predts_inference.Size(); ++i) {
|
||||
// Inference doesn't drop trees, so every prediction must differ from the
// training-mode prediction by more than round-off (kRtEps).
|
||||
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
|
||||
}
|
||||
|
||||
// Release the handle allocated by CreateDMatrix above.
delete pp_dmat;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -159,7 +159,6 @@ TEST(Learner, Json_ModelIO) {
|
||||
|
||||
{
|
||||
std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
|
||||
learner->SetParam("verbosity", "3");
|
||||
for (int32_t iter = 0; iter < kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ TEST(Logging, Basic) {
|
||||
ASSERT_NE(output.find("Test Log Console"), std::string::npos);
|
||||
|
||||
args["silent"] = "False";
|
||||
args["verbosity"] = "1"; // restore
|
||||
args["verbosity"] = "2"; // restore
|
||||
ConsoleLogger::Configure({args.cbegin(), args.cend()});
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,8 @@ class TestModels(unittest.TestCase):
|
||||
def test_dart(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'verbosity': 1}
|
||||
param = {'max_depth': 5, 'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss', 'booster': 'dart', 'verbosity': 1}
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
@@ -52,7 +53,8 @@ class TestModels(unittest.TestCase):
|
||||
# this is prediction
|
||||
preds = bst.predict(dtest, ntree_limit=num_round)
|
||||
labels = dtest.get_label()
|
||||
err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
# error must be smaller than 10%
|
||||
assert err < 0.1
|
||||
|
||||
@@ -68,18 +70,31 @@ class TestModels(unittest.TestCase):
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
|
||||
def my_logloss(preds, dtrain):
|
||||
labels = dtrain.get_label()
|
||||
return 'logloss', np.sum(
|
||||
np.log(np.where(labels, preds, 1 - preds)))
|
||||
|
||||
# check whether custom evaluation metrics work
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist,
|
||||
feval=my_logloss)
|
||||
preds3 = bst.predict(dtest, ntree_limit=num_round)
|
||||
assert all(preds3 == preds)
|
||||
|
||||
# check whether sample_type and normalize_type work
|
||||
num_round = 50
|
||||
param['verbosity'] = 0
|
||||
param['learning_rate'] = 0.1
|
||||
param['rate_drop'] = 0.1
|
||||
preds_list = []
|
||||
for p in [[p0, p1] for p0 in ['uniform', 'weighted'] for p1 in ['tree', 'forest']]:
|
||||
for p in [[p0, p1] for p0 in ['uniform', 'weighted']
|
||||
for p1 in ['tree', 'forest']]:
|
||||
param['sample_type'] = p[0]
|
||||
param['normalize_type'] = p[1]
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||
preds = bst.predict(dtest, ntree_limit=num_round)
|
||||
err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
assert err < 0.1
|
||||
preds_list.append(preds)
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class TestRanking(unittest.TestCase):
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')]
|
||||
bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500,
|
||||
early_stopping_rounds=10, evals=watchlist)
|
||||
early_stopping_rounds=10, evals=watchlist)
|
||||
assert bst.best_score > 0.98
|
||||
|
||||
def test_cv(self):
|
||||
|
||||
Reference in New Issue
Block a user