108 lines
4.6 KiB
R
108 lines
4.6 KiB
R
# Label for this test file: loading models saved by earlier XGBoost releases.
context("Models from previous versions of XGBoost can be loaded")
|
|
|
|
# Reference constants for the compatibility checks in this file.
# run_booster_check() derives the expected tree count from kForests and
# kRounds (times kClasses for the 'cls' model) and compares kClasses with the
# num_class stored in the model configuration.
# NOTE(review): presumably these values mirror the script that generated the
# stored reference models -- confirm before changing any of them.
metadata <- list(
  kRounds = 2, kRows = 1000, kCols = 4,
  kForests = 2, kMaxDepth = 2, kClasses = 3
)
|
|
|
|
# Check the learner-level settings shared by every reference model.
#
# config: parsed model configuration (nested list). Only
#         config$learner$learner_model_param$num_feature and
#         config$learner$learner_train_param$booster are read.
#
# The expected feature count comes from metadata$kCols instead of a repeated
# literal; it is converted with as.character() because the configuration value
# is compared as a string.
run_model_param_check <- function(config) {
  learner <- config$learner
  testthat::expect_equal(learner$learner_model_param$num_feature,
                         as.character(metadata$kCols))
  testthat::expect_equal(learner$learner_train_param$booster, 'gbtree')
}
|
|
|
|
# Count the trees of a model by counting the 'booster[<n>]' header lines in
# its text dump.
#
# booster: model object passed straight to xgb.dump().
#
# Returns an integer: the number of dump lines matching 'booster[<digits>]'.
# The result is 0 (never NULL) when the dump is empty or contains no such line.
get_num_tree <- function(booster) {
  dump <- xgb.dump(booster)
  sum(grepl('booster\\[[0-9]+\\]', dump, perl = TRUE))
}
|
|
|
|
# Validate a loaded reference model against the expectations for its kind.
#
# booster: the loaded model object.
# name:    model kind; 'cls', 'logitraw', 'logit' and 'ltr' have dedicated
#          checks, anything else is required to be 'reg'.
run_booster_check <- function(booster, name) {
  # A booster whose handle is null has to go through xgb.Booster.complete()
  # before xgb.config() can be used on it.
  needs_completion <- inherits(booster, "xgb.Booster") &&
    xgboost:::is.null.handle(booster$handle)
  if (needs_completion) {
    booster <- xgb.Booster.complete(booster)
  }

  config <- jsonlite::fromJSON(xgb.config(booster))
  run_model_param_check(config)

  model_param <- config$learner$learner_model_param
  train_param <- config$learner$learner_train_param
  # Tree count shared by all kinds; the multi-class model multiplies it by
  # the number of classes.
  base_tree_count <- metadata$kForests * metadata$kRounds

  if (name == 'cls') {
    testthat::expect_equal(get_num_tree(booster),
                           base_tree_count * metadata$kClasses)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, 'multi:softmax')
    testthat::expect_equal(as.numeric(model_param$num_class),
                           metadata$kClasses)
  } else if (name == 'logitraw') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(train_param$objective, 'binary:logitraw')
  } else if (name == 'logit') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(train_param$objective, 'binary:logistic')
  } else if (name == 'ltr') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(train_param$objective, 'rank:ndcg')
  } else {
    # Any kind not handled above must be the regression model.
    testthat::expect_equal(name, 'reg')
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, 'reg:squarederror')
  }
}
|
|
|
|
# Download the archive of reference models, then check that every file in its
# 'models' directory loads, predicts on a single all-zero row, and passes
# run_booster_check() for the kind encoded in its file name.
test_that("Models from previous versions of XGBoost can be loaded", {
  bucket <- 'xgboost-ci-jenkins-artifacts'
  region <- 'us-west-2'
  file_name <- 'xgboost_r_model_compatibility_test.zip'
  zipfile <- tempfile(fileext = ".zip")
  extract_dir <- tempdir()
  download.file(paste0('https://', bucket, '.s3-', region, '.amazonaws.com/', file_name),
                destfile = zipfile, mode = 'wb', quiet = TRUE)
  unzip(zipfile, exdir = extract_dir, overwrite = TRUE)
  # The archive is not needed once its contents are extracted.
  unlink(zipfile)
  model_dir <- file.path(extract_dir, 'models')

  pred_data <- xgb.DMatrix(matrix(c(0, 0, 0, 0), nrow = 1, ncol = 4))

  for (x in list.files(model_dir)) {
    model_file <- file.path(model_dir, x)
    # Parse "xgboost-<version>.<kind>.<extension>" from the file name itself so
    # that the temporary directory path cannot influence the match.
    m <- regexec("xgboost-([0-9\\.]+)\\.([a-z]+)\\.[a-z]+", x, perl = TRUE)
    m <- regmatches(x, m)[[1]]
    model_xgb_ver <- m[2]
    name <- m[3]
    is_rds <- endsWith(x, '.rds')
    # compareVersion() is only evaluated for RDS files (short-circuit &&).
    is_old_rds <- is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0

    # Everything the loading code prints to standard output is captured so it
    # can be inspected below.
    cpp_warning <- capture.output({
      if (is_old_rds) {
        # Expect an R warning when a model is loaded from RDS and it was
        # generated by a version older than 1.1.1.1. The file is re-read so
        # each expectation starts from a freshly deserialized object.
        booster <- readRDS(model_file)
        expect_warning(predict(booster, newdata = pred_data))
        booster <- readRDS(model_file)
        expect_warning(run_booster_check(booster, name))
      } else {
        if (is_rds) {
          booster <- readRDS(model_file)
        } else {
          booster <- xgb.load(model_file)
        }
        predict(booster, newdata = pred_data)
        run_booster_check(booster, name)
      }
    })
    cpp_warning <- paste0(cpp_warning, collapse = ' ')

    if (is_rds && !is_old_rds) {
      # For RDS models generated by version 1.1.1.1 or later, the captured
      # output must contain the C++ notice about loading a serialized model.
      notice_pattern <- paste0('.*If you are loading a serialized model ',
                               '\\(like pickle in Python, RDS in R\\).*',
                               'for more details about differences between ',
                               'saving model and serializing.*')
      # cpp_warning is a single collapsed string, so grepl() yields one value.
      expect_true(grepl(notice_pattern, cpp_warning, perl = TRUE))
    }
  }
})
|