Enable loading model from <1.0.0 trained with objective='binary:logitraw' (#6517)

* Enable loading model from <1.0.0 trained with objective='binary:logitraw'

* Add binary:logitraw in model compatibility testing suite

* Feedback from @trivialfis: Override ProbToMargin() for LogisticRaw

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Philip Hyunsu Cho
2020-12-16 16:53:46 -08:00
committed by GitHub
parent bf6cfe3b99
commit ad1a527709
5 changed files with 45 additions and 21 deletions

View File

@@ -2,7 +2,6 @@
# of saved model files from XGBoost version 0.90 and 1.0.x.
library(xgboost)
library(Matrix)
source('./generate_models_params.R')
set.seed(0)
metadata <- list(
@@ -53,11 +52,16 @@ generate_logistic_model <- function () {
y <- sample(0:1, size = metadata$kRows, replace = TRUE)
stopifnot(max(y) == 1, min(y) == 0)
data <- xgb.DMatrix(X, label = y, weight = w)
params <- list(tree_method = 'hist', num_parallel_tree = metadata$kForests,
max_depth = metadata$kMaxDepth, objective = 'binary:logistic')
booster <- xgb.train(params, data, nrounds = metadata$kRounds)
save_booster(booster, 'logit')
objective <- c('binary:logistic', 'binary:logitraw')
name <- c('logit', 'logitraw')
for (i in seq_len(length(objective))) {
data <- xgb.DMatrix(X, label = y, weight = w)
params <- list(tree_method = 'hist', num_parallel_tree = metadata$kForests,
max_depth = metadata$kMaxDepth, objective = objective[i])
booster <- xgb.train(params, data, nrounds = metadata$kRounds)
save_booster(booster, name[i])
}
}
generate_classification_model <- function () {

View File

@@ -39,6 +39,10 @@ run_booster_check <- function (booster, name) {
testthat::expect_equal(config$learner$learner_train_param$objective, 'multi:softmax')
testthat::expect_equal(as.numeric(config$learner$learner_model_param$num_class),
metadata$kClasses)
} else if (name == 'logitraw') {
testthat::expect_equal(get_num_tree(booster), metadata$kForests * metadata$kRounds)
testthat::expect_equal(as.numeric(config$learner$learner_model_param$num_class), 0)
testthat::expect_equal(config$learner$learner_train_param$objective, 'binary:logitraw')
} else if (name == 'logit') {
testthat::expect_equal(get_num_tree(booster), metadata$kForests * metadata$kRounds)
testthat::expect_equal(as.numeric(config$learner$learner_model_param$num_class), 0)