Fix R dart prediction. (#5204)

* Fix R dart prediction and add test.
This commit is contained in:
Jiaming Yuan
2020-01-16 12:11:04 +08:00
committed by GitHub
parent 808f61081b
commit 5199b86126
5 changed files with 80 additions and 35 deletions

View File

@@ -435,9 +435,9 @@ class Dart : public GBTree {
std::fill(out_preds.begin(), out_preds.end(),
model_.learner_model_param_->base_score);
}
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0,
ntree_limit, training);
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread);
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, ntree_limit);
}
void PredictInstance(const SparsePage::Inst &inst,
@@ -489,11 +489,8 @@ class Dart : public GBTree {
std::vector<bst_float>* out_preds,
int num_group,
unsigned tree_begin,
unsigned tree_end,
bool training) {
const int nthread = omp_get_max_threads();
unsigned tree_end) {
CHECK_EQ(num_group, model_.learner_model_param_->num_output_group);
InitThreadTemp(nthread);
std::vector<bst_float>& preds = *out_preds;
CHECK_EQ(model_.param.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";