Fix R dart prediction. (#5204)

* Fix R dart prediction and add test.
This commit is contained in:
Jiaming Yuan 2020-01-16 12:11:04 +08:00 committed by GitHub
parent 808f61081b
commit 5199b86126
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 35 deletions

View File

@ -313,7 +313,7 @@ SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
R_ExternalPtrAddr(dmat), R_ExternalPtrAddr(dmat),
asInteger(option_mask), asInteger(option_mask),
asInteger(ntree_limit), asInteger(ntree_limit),
0, asInteger(training),
&olen, &res)); &olen, &res));
ret = PROTECT(allocVector(REALSXP, olen)); ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) { for (size_t i = 0; i < olen; ++i) {

View File

@ -35,6 +35,54 @@ test_that("train and predict binary classification", {
expect_lt(abs(err_pred1 - err_log), 10e-6) expect_lt(abs(err_pred1 - err_log), 10e-6)
}) })
# Regression test for prediction with the DART booster.
# Checks three things: (1) ntreelimit = 0 and ntreelimit = nrounds give
# identical predictions, (2) predicting with training = TRUE gives different
# predictions from the default, and (3) a model trained through xgboost() and
# one trained through xgb.train() on the same data produce identical
# predictions in all three modes.
test_that("dart prediction works", {
nrounds = 32
# Fixed seed so the synthetic data set is reproducible.
set.seed(1994)
# 100 rows, three standard-normal features.
d <- cbind(
x1 = rnorm(100),
x2 = rnorm(100),
x3 = rnorm(100))
# Non-linear target built from the three features plus Gaussian noise.
y <- d[,"x1"] + d[,"x2"]^2 +
ifelse(d[,"x3"] > .5, d[,"x3"]^2, 2^d[,"x3"]) +
rnorm(100)
# Re-seed immediately before training; the same seed is set again before the
# xgb.train() run below so both runs start from the same R RNG state.
# NOTE(review): presumably DART's tree dropout draws on this RNG state -- confirm.
set.seed(1994)
booster_by_xgboost <- xgboost(data = d, label = y, max_depth = 2, booster = "dart",
rate_drop = 0.5, one_drop = TRUE,
eta = 1, nthread = 2, nrounds = nrounds, objective = "reg:squarederror")
# ntreelimit = 0 and ntreelimit = nrounds are expected to be equivalent
# (presumably 0 means "use all trees" -- confirm against the predict docs).
pred_by_xgboost_0 <- predict(booster_by_xgboost, newdata = d, ntreelimit = 0)
pred_by_xgboost_1 <- predict(booster_by_xgboost, newdata = d, ntreelimit = nrounds)
# Exact elementwise equality is intended here, not a tolerance comparison.
expect_true(all(matrix(pred_by_xgboost_0, byrow=TRUE) == matrix(pred_by_xgboost_1, byrow=TRUE)))
# With training = TRUE the predictions must not all match the default ones.
# NOTE(review): presumably because dropout is applied in training mode -- confirm.
pred_by_xgboost_2 <- predict(booster_by_xgboost, newdata = d, training = TRUE)
expect_false(all(matrix(pred_by_xgboost_0, byrow=TRUE) == matrix(pred_by_xgboost_2, byrow=TRUE)))
# Second model: same data and seed, trained through the xgb.train() interface.
set.seed(1994)
dtrain <- xgb.DMatrix(data=d, info = list(label=y))
# Note: unlike the xgboost() call above, this run uses nthread = 1, sets
# tree_method = "exact" explicitly and raises verbosity to 3; the assertions
# below still require bit-identical predictions between the two models.
booster_by_train <- xgb.train( params = list(
booster = "dart",
max_depth = 2,
eta = 1,
rate_drop = 0.5,
one_drop = TRUE,
nthread = 1,
tree_method= "exact",
verbosity = 3,
objective = "reg:squarederror"
),
data = dtrain,
nrounds = nrounds
)
# Same three prediction modes as above, this time on the xgb.DMatrix.
pred_by_train_0 <- predict(booster_by_train, newdata = dtrain, ntreelimit = 0)
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, ntreelimit = nrounds)
pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
# Each mode must agree exactly with the corresponding xgboost() prediction.
expect_true(all(matrix(pred_by_train_0, byrow=TRUE) == matrix(pred_by_xgboost_0, byrow=TRUE)))
expect_true(all(matrix(pred_by_train_1, byrow=TRUE) == matrix(pred_by_xgboost_1, byrow=TRUE)))
expect_true(all(matrix(pred_by_train_2, byrow=TRUE) == matrix(pred_by_xgboost_2, byrow=TRUE)))
})
test_that("train and predict softprob", { test_that("train and predict softprob", {
lb <- as.numeric(iris$Species) - 1 lb <- as.numeric(iris$Species) - 1
set.seed(11) set.seed(11)

View File

@ -157,7 +157,7 @@ test_that("SHAPs sum to predictions, with or without DART", {
params = c( params = c(
list( list(
booster = booster, booster = booster,
objective = "reg:linear", objective = "reg:squarederror",
eval_metric = "rmse"), eval_metric = "rmse"),
if (booster == "dart") if (booster == "dart")
list(rate_drop = .01, one_drop = T)), list(rate_drop = .01, one_drop = T)),

View File

@ -435,9 +435,9 @@ class Dart : public GBTree {
std::fill(out_preds.begin(), out_preds.end(), std::fill(out_preds.begin(), out_preds.end(),
model_.learner_model_param_->base_score); model_.learner_model_param_->base_score);
} }
const int nthread = omp_get_max_threads();
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, InitThreadTemp(nthread);
ntree_limit, training); PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, ntree_limit);
} }
void PredictInstance(const SparsePage::Inst &inst, void PredictInstance(const SparsePage::Inst &inst,
@ -489,11 +489,8 @@ class Dart : public GBTree {
std::vector<bst_float>* out_preds, std::vector<bst_float>* out_preds,
int num_group, int num_group,
unsigned tree_begin, unsigned tree_begin,
unsigned tree_end, unsigned tree_end) {
bool training) {
const int nthread = omp_get_max_threads();
CHECK_EQ(num_group, model_.learner_model_param_->num_output_group); CHECK_EQ(num_group, model_.learner_model_param_->num_output_group);
InitThreadTemp(nthread);
std::vector<bst_float>& preds = *out_preds; std::vector<bst_float>& preds = *out_preds;
CHECK_EQ(model_.param.size_leaf_vector, 0) CHECK_EQ(model_.param.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far"; << "size_leaf_vector is enforced to 0 so far";