[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify DropTrees calling logic * Add `training` parameter for prediction method. * [Breaking]: Add `training` to C API. * Change for R and Python custom objective. * Correct comment. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/gbm/gbtree.h"
|
||||
@@ -18,7 +20,7 @@ TEST(GBTree, SelectTreeMethod) {
|
||||
mparam.num_output_group = 1;
|
||||
|
||||
std::vector<std::shared_ptr<DMatrix> > caches;
|
||||
std::unique_ptr<GradientBooster> p_gbm{
|
||||
std::unique_ptr<GradientBooster> p_gbm {
|
||||
GradientBooster::Create("gbtree", &generic_param, &mparam, caches)};
|
||||
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
|
||||
|
||||
@@ -175,4 +177,41 @@ TEST(Dart, Json_IO) {
|
||||
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
|
||||
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
|
||||
}
|
||||
|
||||
TEST(Dart, Prediction) {
|
||||
size_t constexpr kRows = 16, kCols = 10;
|
||||
|
||||
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
|
||||
auto& p_mat = *pp_dmat;
|
||||
|
||||
std::vector<bst_float> labels (kRows);
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
labels[i] = i % 2;
|
||||
}
|
||||
p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
|
||||
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
|
||||
learner->SetParam("booster", "dart");
|
||||
learner->SetParam("rate_drop", "0.5");
|
||||
learner->Configure();
|
||||
|
||||
for (size_t i = 0; i < 16; ++i) {
|
||||
learner->UpdateOneIter(i, p_mat.get());
|
||||
}
|
||||
|
||||
HostDeviceVector<float> predts_training;
|
||||
learner->Predict(p_mat.get(), false, &predts_training, 0, true);
|
||||
HostDeviceVector<float> predts_inference;
|
||||
learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
|
||||
|
||||
auto& h_predts_training = predts_training.ConstHostVector();
|
||||
auto& h_predts_inference = predts_inference.ConstHostVector();
|
||||
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
|
||||
for (size_t i = 0; i < predts_inference.Size(); ++i) {
|
||||
// Inference doesn't drop tree.
|
||||
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
|
||||
}
|
||||
|
||||
delete pp_dmat;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user