[breaking] Add prediction function for DMatrix and use inplace predict for dask. (#6668)
* Add a new API function for predicting on `DMatrix`. This function aligns with the rest of the `XGBoosterPredictFrom*` functions in the semantics of its function arguments. * Purge `ntree_limit` from libxgboost; use iteration instead. * [dask] Use `inplace_predict` by default for dask sklearn models. * [dask] Run prediction shape inference on the worker instead of the client. The breaking change is in the Python sklearn `apply` function: it is now consistent with the other prediction functions, where `best_iteration` is used by default.
This commit is contained in:
@@ -51,6 +51,53 @@ TEST(GBTree, SelectTreeMethod) {
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
|
||||
// Checks how GBTree::PredictBatch updates PredictionCacheEntry::version as
// boosting rounds are added and as different iteration ranges are requested.
// NOTE(review): the last two PredictBatch arguments look like the begin/end of
// a half-open iteration range (see the "[1, 3)" comment below) -- confirm
// against the GBTree::PredictBatch declaration.
TEST(GBTree, PredictionCache) {
|
||||
size_t constexpr kRows = 100, kCols = 10;
|
||||
GenericParameter generic_param;
|
||||
generic_param.UpdateAllowUnknown(Args{});
|
||||
LearnerModelParam mparam;
|
||||
mparam.base_score = 0.5;
|
||||
mparam.num_feature = kCols;
|
||||
mparam.num_output_group = 1;
|
||||
|
||||
// Build the booster through the generic factory, then downcast to the
// concrete gbm::GBTree so its methods can be called directly.
std::unique_ptr<GradientBooster> p_gbm {
|
||||
GradientBooster::Create("gbtree", &generic_param, &mparam)};
|
||||
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
|
||||
|
||||
gbtree.Configure({{"tree_method", "hist"}});
|
||||
// Random input data sized kRows x kCols, plus random gradients for kRows rows.
// NOTE(review): the third generator argument (0) is presumably the sparsity --
// confirm against RandomDataGenerator.
auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
auto gpair = GenerateRandomGradients(kRows);
|
||||
PredictionCacheEntry out_predictions;
|
||||
// First boosting round.
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
|
||||
|
||||
// Predict with the range arguments (0, 0); after one round the cache version
// is expected to be 1.
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
|
||||
ASSERT_EQ(1, out_predictions.version);
|
||||
// Keep a copy of the predictions after the first round; it is compared with
// the final prediction at the end of this test.
std::vector<float> first_iter = out_predictions.predictions.HostVector();
|
||||
// Add 1 more boosted round
|
||||
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
|
||||
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
|
||||
ASSERT_EQ(2, out_predictions.version);
|
||||
// Update the cache for all rounds
|
||||
// The version is manually reset to 0 here; the next PredictBatch call is
// expected to bring it back to 2.
out_predictions.version = 0;
|
||||
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
|
||||
ASSERT_EQ(2, out_predictions.version);
|
||||
|
||||
// Third boosting round.
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
|
||||
// drop the cache.
|
||||
// With the range arguments (1, 2) the version is expected to be 0 afterwards.
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2);
|
||||
ASSERT_EQ(0, out_predictions.version);
|
||||
// half open set [1, 3)
|
||||
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 3);
|
||||
ASSERT_EQ(0, out_predictions.version);
|
||||
// iteration end
|
||||
// With the range arguments (0, 2) the version is expected to be 2 afterwards.
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 2);
|
||||
ASSERT_EQ(2, out_predictions.version);
|
||||
// restart the cache when end iteration is smaller than cache version
|
||||
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 1);
|
||||
ASSERT_EQ(1, out_predictions.version);
|
||||
// The prediction for the range (0, 1) must equal the one saved after the
// first boosting round.
ASSERT_EQ(out_predictions.predictions.HostVector(), first_iter);
|
||||
}
|
||||
|
||||
TEST(GBTree, WrongUpdater) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
Reference in New Issue
Block a user