Fix GPU ID and prediction cache from pickle (#5086)
* Hack for saving GPU ID. * Declare prediction cache on GBTree. * Add a simple test. * Add `auto` option for GPU Predictor.
This commit is contained in:
@@ -29,21 +29,17 @@ TEST(GBTree, SelectTreeMethod) {
|
||||
ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune");
|
||||
gbtree.Configure({{"tree_method", "hist"}, {"num_feature", n_feat}});
|
||||
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
|
||||
ASSERT_EQ(tparam.predictor, "cpu_predictor");
|
||||
gbtree.Configure({{"booster", "dart"}, {"tree_method", "hist"},
|
||||
{"num_feature", n_feat}});
|
||||
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
|
||||
ASSERT_EQ(tparam.predictor, "cpu_predictor");
|
||||
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
generic_param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
gbtree.Configure({{"tree_method", "gpu_hist"}, {"num_feature", n_feat}});
|
||||
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
|
||||
ASSERT_EQ(tparam.predictor, "gpu_predictor");
|
||||
gbtree.Configure({{"booster", "dart"}, {"tree_method", "gpu_hist"},
|
||||
{"num_feature", n_feat}});
|
||||
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
|
||||
ASSERT_EQ(tparam.predictor, "gpu_predictor");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user