diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 7a911a4b8..43048061e 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -57,6 +57,14 @@ void GBTree::Configure(const Args& cfg) { monitor_.Init("GBTree"); + if (tparam_.tree_method == TreeMethod::kGPUHist && + std::none_of(cfg.cbegin(), cfg.cend(), + [](std::pair const& arg) { + return arg.first == "predictor"; + })) { + tparam_.predictor = "gpu_predictor"; + } + configured_ = true; } @@ -88,7 +96,7 @@ void GBTree::PerformTreeMethodHeuristic(std::map const // set, since only experts are expected to do so. return; } - + // tparam_ is set before calling this function. if (tparam_.tree_method != TreeMethod::kAuto) { return; } @@ -112,7 +120,6 @@ void GBTree::PerformTreeMethodHeuristic(std::map const tparam_.tree_method = TreeMethod::kApprox; } else { tparam_.tree_method = TreeMethod::kExact; - tparam_.updater_seq = "grow_colmaker,prune"; } LOG(DEBUG) << "Using tree method: " << static_cast(tparam_.tree_method); } @@ -131,8 +138,9 @@ void GBTree::ConfigureUpdaters(const std::map& cfg) { /* Choose updaters according to tree_method parameters */ switch (tparam_.tree_method) { case TreeMethod::kAuto: - // Use heuristic to choose between 'exact' and 'approx' - // This choice is deferred to PerformTreeMethodHeuristic(). + // Use heuristic to choose between 'exact' and 'approx' This + // choice is carried out in PerformTreeMethodHeuristic() before + // calling this function. break; case TreeMethod::kApprox: tparam_.updater_seq = "grow_histmaker,prune"; diff --git a/src/learner.cc b/src/learner.cc index 8e694882d..0eb175f3d 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -530,7 +530,7 @@ class LearnerImpl : public Learner { void PredictRaw(DMatrix* data, HostDeviceVector* out_preds, unsigned ntree_limit = 0) const { CHECK(gbm_ != nullptr) - << "Predict must happen after Load or InitModel"; + << "Predict must happen after Load or configuration"; this->ValidateDMatrix(data); gbm_->PredictBatch(data, out_preds, ntree_limit); } @@ -619,11 +619,11 @@ class LearnerImpl : public Learner { void ValidateDMatrix(DMatrix* p_fmat) const { MetaInfo const& info = p_fmat->Info(); - auto const& weights = info.weights_.HostVector(); - if (info.group_ptr_.size() != 0 && weights.size() != 0) { - CHECK(weights.size() == info.group_ptr_.size() - 1) + auto const& weights = info.weights_; + if (info.group_ptr_.size() != 0 && weights.Size() != 0) { + CHECK(weights.Size() == info.group_ptr_.size() - 1) << "\n" - << "weights size: " << weights.size() << ", " + << "weights size: " << weights.Size() << ", " << "groups size: " << info.group_ptr_.size() -1 << ", " << "num rows: " << p_fmat->Info().num_row_ << "\n" << "Number of weights should be equal to number of groups in ranking task.";