Prevent copying data to host. (#4795)
This commit is contained in:
parent
3fa2ceb193
commit
fba298fecb
@ -57,6 +57,14 @@ void GBTree::Configure(const Args& cfg) {
|
||||
|
||||
monitor_.Init("GBTree");
|
||||
|
||||
if (tparam_.tree_method == TreeMethod::kGPUHist &&
|
||||
std::none_of(cfg.cbegin(), cfg.cend(),
|
||||
[](std::pair<std::string, std::string> const& arg) {
|
||||
return arg.first == "predictor";
|
||||
})) {
|
||||
tparam_.predictor = "gpu_predictor";
|
||||
}
|
||||
|
||||
configured_ = true;
|
||||
}
|
||||
|
||||
@ -88,7 +96,7 @@ void GBTree::PerformTreeMethodHeuristic(std::map<std::string, std::string> const
|
||||
// set, since only experts are expected to do so.
|
||||
return;
|
||||
}
|
||||
|
||||
// tparam_ is set before calling this function.
|
||||
if (tparam_.tree_method != TreeMethod::kAuto) {
|
||||
return;
|
||||
}
|
||||
@ -112,7 +120,6 @@ void GBTree::PerformTreeMethodHeuristic(std::map<std::string, std::string> const
|
||||
tparam_.tree_method = TreeMethod::kApprox;
|
||||
} else {
|
||||
tparam_.tree_method = TreeMethod::kExact;
|
||||
tparam_.updater_seq = "grow_colmaker,prune";
|
||||
}
|
||||
LOG(DEBUG) << "Using tree method: " << static_cast<int>(tparam_.tree_method);
|
||||
}
|
||||
@ -131,8 +138,9 @@ void GBTree::ConfigureUpdaters(const std::map<std::string, std::string>& cfg) {
|
||||
/* Choose updaters according to tree_method parameters */
|
||||
switch (tparam_.tree_method) {
|
||||
case TreeMethod::kAuto:
|
||||
// Use heuristic to choose between 'exact' and 'approx'
|
||||
// This choice is deferred to PerformTreeMethodHeuristic().
|
||||
// Use heuristic to choose between 'exact' and 'approx' This
|
||||
// choice is carried out in PerformTreeMethodHeuristic() before
|
||||
// calling this function.
|
||||
break;
|
||||
case TreeMethod::kApprox:
|
||||
tparam_.updater_seq = "grow_histmaker,prune";
|
||||
|
||||
@ -530,7 +530,7 @@ class LearnerImpl : public Learner {
|
||||
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_ != nullptr)
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
<< "Predict must happen after Load or configuration";
|
||||
this->ValidateDMatrix(data);
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
@ -619,11 +619,11 @@ class LearnerImpl : public Learner {
|
||||
|
||||
void ValidateDMatrix(DMatrix* p_fmat) const {
|
||||
MetaInfo const& info = p_fmat->Info();
|
||||
auto const& weights = info.weights_.HostVector();
|
||||
if (info.group_ptr_.size() != 0 && weights.size() != 0) {
|
||||
CHECK(weights.size() == info.group_ptr_.size() - 1)
|
||||
auto const& weights = info.weights_;
|
||||
if (info.group_ptr_.size() != 0 && weights.Size() != 0) {
|
||||
CHECK(weights.Size() == info.group_ptr_.size() - 1)
|
||||
<< "\n"
|
||||
<< "weights size: " << weights.size() << ", "
|
||||
<< "weights size: " << weights.Size() << ", "
|
||||
<< "groups size: " << info.group_ptr_.size() -1 << ", "
|
||||
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
|
||||
<< "Number of weights should be equal to number of groups in ranking task.";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user