Prevent copying data to host. (#4795)

This commit is contained in:
Jiaming Yuan
2019-08-20 23:06:27 -04:00
committed by GitHub
parent 3fa2ceb193
commit fba298fecb
2 changed files with 17 additions and 9 deletions

View File

@@ -530,7 +530,7 @@ class LearnerImpl : public Learner {
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr)
<< "Predict must happen after Load or InitModel";
<< "Predict must happen after Load or configuration";
this->ValidateDMatrix(data);
gbm_->PredictBatch(data, out_preds, ntree_limit);
}
@@ -619,11 +619,11 @@ class LearnerImpl : public Learner {
void ValidateDMatrix(DMatrix* p_fmat) const {
MetaInfo const& info = p_fmat->Info();
auto const& weights = info.weights_.HostVector();
if (info.group_ptr_.size() != 0 && weights.size() != 0) {
CHECK(weights.size() == info.group_ptr_.size() - 1)
auto const& weights = info.weights_;
if (info.group_ptr_.size() != 0 && weights.Size() != 0) {
CHECK(weights.Size() == info.group_ptr_.size() - 1)
<< "\n"
<< "weights size: " << weights.size() << ", "
<< "weights size: " << weights.Size() << ", "
<< "groups size: " << info.group_ptr_.size() -1 << ", "
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
<< "Number of weights should be equal to number of groups in ranking task.";