Fix GPU ID and prediction cache from pickle (#5086)

* Hack for saving GPU ID.

* Declare prediction cache on GBTree.

* Add a simple test.

* Add `auto` option for GPU Predictor.
This commit is contained in:
Jiaming Yuan
2019-12-07 16:02:06 +08:00
committed by GitHub
parent 7ef5b78003
commit 608ebbe444
17 changed files with 362 additions and 182 deletions

View File

@@ -29,21 +29,17 @@ TEST(GBTree, SelectTreeMethod) {
ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune");
gbtree.Configure({{"tree_method", "hist"}, {"num_feature", n_feat}});
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
ASSERT_EQ(tparam.predictor, "cpu_predictor");
gbtree.Configure({{"booster", "dart"}, {"tree_method", "hist"},
{"num_feature", n_feat}});
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
ASSERT_EQ(tparam.predictor, "cpu_predictor");
#ifdef XGBOOST_USE_CUDA
generic_param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
gbtree.Configure({{"tree_method", "gpu_hist"}, {"num_feature", n_feat}});
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
ASSERT_EQ(tparam.predictor, "gpu_predictor");
gbtree.Configure({{"booster", "dart"}, {"tree_method", "gpu_hist"},
{"num_feature", n_feat}});
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
ASSERT_EQ(tparam.predictor, "gpu_predictor");
#endif
}