- Save the updater sequence as an array instead of an object. - Warn only once. Compatibility is kept, but we should be able to break it, since the config is not loaded in pickled models and is declared unstable.
96 lines
4.0 KiB
R
96 lines
4.0 KiB
R
context("Models from previous versions of XGBoost can be loaded")

# Settings shared by the reference models in the downloaded archive (presumably
# the parameters they were generated with -- confirm against the generator
# script). run_booster_check() derives its expected tree counts from
# kForests, kRounds and kClasses.
metadata <- list(kRounds = 2, kRows = 1000, kCols = 4,
                 kForests = 2, kMaxDepth = 2, kClasses = 3)
|
|
|
|
# Expectations shared by every reference model, independent of its objective:
# four input features and the gbtree booster.
# config: the learner configuration parsed from xgb.config() JSON.
run_model_param_check <- function(config) {
  learner <- config$learner
  testthat::expect_equal(learner$learner_model_param$num_feature, "4")
  testthat::expect_equal(learner$learner_train_param$booster, "gbtree")
}
|
|
|
|
# Count the trees in a booster by counting the "booster[<i>]" header lines of
# its text dump. Returns an integer; 0L when the dump contains no trees
# (the previous Reduce()-based version returned NULL in that case).
get_num_tree <- function(booster) {
  dump <- xgb.dump(booster)
  sum(grepl("booster\\[[0-9]+\\]", dump, perl = TRUE))
}
|
|
|
|
# Check that a loaded booster's configuration matches the reference model
# identified by `name` ("cls", "logitraw", "logit", "ltr"; anything else is
# required to be "reg").
# booster: an xgb.Booster object or a handle accepted by xgb.config().
run_booster_check <- function(booster, name) {
  # An xgb.Booster whose native handle is null (as can happen after readRDS())
  # has to be restored with xgb.Booster.complete() before xgb.config() works.
  handle_is_null <- inherits(booster, "xgb.Booster") &&
    xgboost:::is.null.handle(booster$handle)
  if (handle_is_null) {
    booster <- xgb.Booster.complete(booster)
  }

  config <- jsonlite::fromJSON(xgb.config(booster))
  run_model_param_check(config)

  model_param <- config$learner$learner_model_param
  train_param <- config$learner$learner_train_param
  # Tree count expected for single-output objectives; the multi-class model is
  # expected to hold this many trees per class.
  n_trees <- metadata$kForests * metadata$kRounds

  if (name == "cls") {
    testthat::expect_equal(get_num_tree(booster), n_trees * metadata$kClasses)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, "multi:softmax")
    testthat::expect_equal(as.numeric(model_param$num_class),
                           metadata$kClasses)
  } else if (name %in% c("logitraw", "logit")) {
    testthat::expect_equal(get_num_tree(booster), n_trees)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    expected_objective <- if (name == "logit") "binary:logistic" else "binary:logitraw"
    testthat::expect_equal(train_param$objective, expected_objective)
  } else if (name == "ltr") {
    testthat::expect_equal(get_num_tree(booster), n_trees)
    testthat::expect_equal(train_param$objective, "rank:ndcg")
  } else {
    testthat::expect_equal(name, "reg")
    testthat::expect_equal(get_num_tree(booster), n_trees)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, "reg:squarederror")
  }
}
|
|
|
|
# Download the archive of reference models published by CI, then for every
# model in it: load it, run a prediction, and check its configuration with
# run_booster_check(). Requires network access to the S3 bucket below.
test_that("Models from previous versions of XGBoost can be loaded", {
  bucket <- "xgboost-ci-jenkins-artifacts"
  region <- "us-west-2"
  file_name <- "xgboost_r_model_compatibility_test.zip"
  zipfile <- tempfile(fileext = ".zip")
  # Extract into a fresh directory rather than the shared session tempdir(), so
  # files left over from an earlier run cannot be picked up; clean up on exit.
  extract_dir <- tempfile(pattern = "xgb_compat_")
  on.exit(unlink(c(zipfile, extract_dir), recursive = TRUE), add = TRUE)

  url <- paste0("https://", bucket, ".s3-", region, ".amazonaws.com/",
                file_name)
  download.file(url, destfile = zipfile, mode = "wb", quiet = TRUE)
  unzip(zipfile, exdir = extract_dir, overwrite = TRUE)
  model_dir <- file.path(extract_dir, "models")

  pred_data <- xgb.DMatrix(matrix(c(0, 0, 0, 0), nrow = 1, ncol = 4))

  for (entry in list.files(model_dir)) {
    model_file <- file.path(model_dir, entry)
    # Expected file names: xgboost-<version>.<model name>.<extension>. Match the
    # bare file name so the temporary directory path cannot affect the match.
    m <- regmatches(
      entry,
      regexec("xgboost-([0-9\\.]+)\\.([a-z]+)\\.[a-z]+", entry, perl = TRUE)
    )[[1]]
    if (length(m) != 3) {
      # Without a version and model name the checks below cannot run; report
      # the file explicitly instead of failing later on an NA comparison.
      fail(paste0("Unexpected file in model archive: ", entry))
      next
    }
    model_xgb_ver <- m[2]
    name <- m[3]
    is_rds <- endsWith(entry, ".rds")

    # RDS models written by an xgboost version older than 1.1.1.1 are expected
    # to raise an R warning. The model is re-read before the second check,
    # presumably because the warning is emitted only once per loaded object.
    if (is_rds && compareVersion(model_xgb_ver, "1.1.1.1") < 0) {
      booster <- readRDS(model_file)
      expect_warning(predict(booster, newdata = pred_data))
      booster <- readRDS(model_file)
      expect_warning(run_booster_check(booster, name))
    } else {
      booster <- if (is_rds) readRDS(model_file) else xgb.load(model_file)
      # Prediction must succeed; its output is not inspected.
      predict(booster, newdata = pred_data)
      run_booster_check(booster, name)
    }
  }
})
|