|
|
|
|
@@ -34,6 +34,7 @@ DMLC_REGISTRY_FILE_TAG(gbtree);
|
|
|
|
|
|
|
|
|
|
void GBTree::Configure(const Args& cfg) {
|
|
|
|
|
this->cfg_ = cfg;
|
|
|
|
|
std::string updater_seq = tparam_.updater_seq;
|
|
|
|
|
tparam_.UpdateAllowUnknown(cfg);
|
|
|
|
|
|
|
|
|
|
model_.Configure(cfg);
|
|
|
|
|
@@ -75,24 +76,31 @@ void GBTree::Configure(const Args& cfg) {
|
|
|
|
|
"`tree_method` parameter instead.";
|
|
|
|
|
// Don't drive users to silent XGBoost.
|
|
|
|
|
showed_updater_warning_ = true;
|
|
|
|
|
} else {
|
|
|
|
|
this->ConfigureUpdaters();
|
|
|
|
|
LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& up : updaters_) {
|
|
|
|
|
up->Configure(cfg);
|
|
|
|
|
this->ConfigureUpdaters();
|
|
|
|
|
if (updater_seq != tparam_.updater_seq) {
|
|
|
|
|
updaters_.clear();
|
|
|
|
|
this->InitUpdater(cfg);
|
|
|
|
|
} else {
|
|
|
|
|
for (auto &up : updaters_) {
|
|
|
|
|
up->Configure(cfg);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
configured_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// FIXME(trivialfis): This handles updaters and predictor. Because the choice of updaters
|
|
|
|
|
// depends on whether external memory is used and how large the dataset is. We can remove the
|
|
|
|
|
// dependency on DMatrix once `hist` tree method can handle external memory so that we can
|
|
|
|
|
// make it default.
|
|
|
|
|
// FIXME(trivialfis): This handles updaters. Because the choice of updaters depends on
|
|
|
|
|
// whether external memory is used and how large the dataset is. We can remove the dependency
|
|
|
|
|
// on DMatrix once `hist` tree method can handle external memory so that we can make it
|
|
|
|
|
// default.
|
|
|
|
|
void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
|
|
|
|
|
CHECK(this->configured_);
|
|
|
|
|
std::string updater_seq = tparam_.updater_seq;
|
|
|
|
|
CHECK(tparam_.GetInitialised());
|
|
|
|
|
|
|
|
|
|
tparam_.UpdateAllowUnknown(cfg);
|
|
|
|
|
|
|
|
|
|
this->PerformTreeMethodHeuristic(fmat);
|
|
|
|
|
this->ConfigureUpdaters();
|
|
|
|
|
@@ -101,9 +109,8 @@ void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
|
|
|
|
|
if (updater_seq != tparam_.updater_seq) {
|
|
|
|
|
LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
|
|
|
|
|
this->updaters_.clear();
|
|
|
|
|
this->InitUpdater(cfg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
this->InitUpdater(cfg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
|
|
|
|
|
@@ -141,6 +148,9 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GBTree::ConfigureUpdaters() {
|
|
|
|
|
if (specified_updater_) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// `updater` parameter was manually specified
|
|
|
|
|
/* Choose updaters according to tree_method parameters */
|
|
|
|
|
switch (tparam_.tree_method) {
|
|
|
|
|
@@ -289,6 +299,46 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
|
|
|
|
monitor_.Stop("CommitModel");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Restore the booster configuration from the JSON object `in`.
//
// Steps, in order:
//   1. check that `in["name"]` is "gbtree";
//   2. read the training parameters stored under "gbtree_train_param";
//   3. when no GPU is visible, reset a GPU predictor to "auto" and a
//      `TreeMethod::kGPUHist` tree method to "hist" (the latter with a warning);
//   4. rebuild `updaters_` from the entries found under "updater";
//   5. restore `specified_updater_` from "specified_updater".
void GBTree::LoadConfig(Json const& in) {
  CHECK_EQ(get<String>(in["name"]), "gbtree");
  fromJson(in["gbtree_train_param"], &tparam_);

  // Both fallbacks below only apply on a machine without any visible GPU.
  bool const no_visible_gpu = xgboost::common::AllVisibleGPUs() == 0;
  if (no_visible_gpu) {
    if (tparam_.predictor == PredictorType::kGPUPredictor) {
      tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}});
    }
    if (tparam_.tree_method == TreeMethod::kGPUHist) {
      tparam_.UpdateAllowUnknown(Args{{"tree_method", "hist"}});
      LOG(WARNING)
          << "Loading from a raw memory buffer on CPU only machine. "
             "Change tree_method to hist.";
    }
  }

  // Discard the current updaters and create one per entry of the "updater" object;
  // the key is the updater's name, the value is that updater's own configuration.
  auto const& updater_configs = get<Object const>(in["updater"]);
  updaters_.clear();
  for (auto const& entry : updater_configs) {
    auto const& updater_name = entry.first;
    auto const& updater_config = entry.second;
    std::unique_ptr<TreeUpdater> updater{
        TreeUpdater::Create(updater_name, generic_param_)};
    updater->LoadConfig(updater_config);
    updaters_.push_back(std::move(updater));
  }

  specified_updater_ = get<Boolean>(in["specified_updater"]);
}
|
|
|
|
|
|
|
|
|
|
// Write the booster configuration into the JSON value pointed to by `p_out`.
//
// Keys written: "name" (always "gbtree"), "gbtree_train_param" (the training
// parameters), "updater" (one sub-object per entry of `updaters_`, keyed by the
// updater's name and filled in by that updater's own SaveConfig), and
// "specified_updater" (the current value of `specified_updater_`).
void GBTree::SaveConfig(Json* p_out) const {
  auto& config = *p_out;

  config["name"] = String("gbtree");
  config["gbtree_train_param"] = toJson(tparam_);

  // Start from an empty object so only the current updaters are recorded.
  config["updater"] = Object();
  auto& updater_section = config["updater"];
  for (auto const& updater : updaters_) {
    auto const& updater_name = updater->Name();
    updater_section[updater_name] = Object();
    updater->SaveConfig(&updater_section[updater_name]);
  }

  config["specified_updater"] = Boolean{specified_updater_};
}
|
|
|
|
|
|
|
|
|
|
void GBTree::LoadModel(Json const& in) {
|
|
|
|
|
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
|
|
|
|
model_.LoadModel(in["model"]);
|
|
|
|
|
@@ -324,7 +374,7 @@ class Dart : public GBTree {
|
|
|
|
|
for (size_t i = 0; i < weight_drop_.size(); ++i) {
|
|
|
|
|
j_weight_drop[i] = Number(weight_drop_[i]);
|
|
|
|
|
}
|
|
|
|
|
out["weight_drop"] = Array(j_weight_drop);
|
|
|
|
|
out["weight_drop"] = Array(std::move(j_weight_drop));
|
|
|
|
|
}
|
|
|
|
|
void LoadModel(Json const& in) override {
|
|
|
|
|
CHECK_EQ(get<String>(in["name"]), "dart");
|
|
|
|
|
@@ -352,6 +402,21 @@ class Dart : public GBTree {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Restore the DART booster configuration from the JSON object `in`.
//
// `in["name"]` must be "dart". The nested "gbtree" object is handed to
// GBTree::LoadConfig first; afterwards the DART-specific parameters are read
// from "dart_train_param" into `dparam_`.
void LoadConfig(Json const& in) override {
  CHECK_EQ(get<String>(in["name"]), "dart");
  GBTree::LoadConfig(in["gbtree"]);
  fromJson(in["dart_train_param"], &dparam_);
}
|
|
|
|
|
// Write the DART booster configuration into the JSON value pointed to by `p_out`.
//
// Keys written: "name" (always "dart"), "gbtree" (a fresh object filled in by
// GBTree::SaveConfig) and "dart_train_param" (the contents of `dparam_`).
void SaveConfig(Json* p_out) const override {
  auto& config = *p_out;
  config["name"] = String("dart");

  // The base-class configuration lives in its own nested object.
  config["gbtree"] = Object();
  GBTree::SaveConfig(&config["gbtree"]);

  config["dart_train_param"] = toJson(dparam_);
}
|
|
|
|
|
|
|
|
|
|
// predict the leaf scores with dropout if ntree_limit = 0
|
|
|
|
|
void PredictBatch(DMatrix* p_fmat,
|
|
|
|
|
HostDeviceVector<bst_float>* out_preds,
|
|
|
|
|
|