xgboost/R-package/tests/testthat/test_model_compatibility.R
david-cortes d3a8d284ab
[R] On-demand serialization + standardization of attributes (#9924)
---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
2024-01-11 05:08:42 +08:00

97 lines
3.9 KiB
R

# testthat context label for the tests that follow in this file.
context("Models from previous versions of XGBoost can be loaded")
# Constants the booster checks compare loaded models against. The checks
# visible here read kForests, kRounds and kClasses (expected tree counts and
# class count); the remaining entries presumably mirror the settings used to
# generate the stored models -- TODO confirm against the generator script.
metadata <- list(
  kRounds = 2, kRows = 1000, kCols = 4,
  kForests = 2, kMaxDepth = 2, kClasses = 3
)
# Checks the configuration entries that every stored model is expected to
# share: a feature count reported as the string "4" and the "gbtree" booster.
#
# config: nested list (run_booster_check() passes the result of xgb.config());
#         only learner$learner_model_param$num_feature and
#         learner$learner_train_param$booster are read.
run_model_param_check <- function(config) {
  learner <- config$learner
  testthat::expect_equal(learner$learner_model_param$num_feature, "4")
  testthat::expect_equal(learner$learner_train_param$booster, "gbtree")
}
# Counts the trees of a booster from its text dump.
#
# booster: model object accepted by xgb.dump().
#
# Returns, as an integer, the number of dump lines that contain a
# "booster[<index>]" header; callers use this as the tree count. Yields 0L
# when no line matches, including when the dump has no lines at all.
get_num_tree <- function(booster) {
  dump <- xgb.dump(booster)
  # grepl() gives one TRUE per line holding a header, so summing counts them.
  # Unlike folding per-line match lengths with Reduce(), this cannot return
  # NULL for an empty dump.
  sum(grepl("booster\\[[0-9]+\\]", dump, perl = TRUE))
}
# Verifies that a loaded booster has the configuration and tree count expected
# for its model kind.
#
# booster: model object; passed to xgb.config() and get_num_tree().
# name:    model kind. "cls", "logitraw", "logit" and "ltr" each have their
#          own expectations; any other value is asserted to be "reg".
run_booster_check <- function(booster, name) {
  config <- xgb.config(booster)
  run_model_param_check(config)

  model_param <- config$learner$learner_model_param
  objective <- config$learner$learner_train_param$objective
  # Expected tree count for every kind except "cls", which multiplies it by
  # the number of classes.
  base_tree_count <- metadata$kForests * metadata$kRounds

  # Every kind checks the tree count; only the expected value differs.
  expect_tree_count <- function(expected) {
    testthat::expect_equal(get_num_tree(booster), expected)
  }

  if (name == "cls") {
    expect_tree_count(base_tree_count * metadata$kClasses)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(objective, "multi:softmax")
    testthat::expect_equal(as.numeric(model_param$num_class), metadata$kClasses)
  } else if (name == "logitraw") {
    expect_tree_count(base_tree_count)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(objective, "binary:logitraw")
  } else if (name == "logit") {
    expect_tree_count(base_tree_count)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(objective, "binary:logistic")
  } else if (name == "ltr") {
    expect_tree_count(base_tree_count)
    testthat::expect_equal(objective, "rank:ndcg")
  } else {
    # Anything not handled above must be the regression model.
    testthat::expect_equal(name, "reg")
    expect_tree_count(base_tree_count)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(objective, "reg:squarederror")
  }
}
# Downloads an archive of models saved by earlier XGBoost releases and checks
# that each non-RDS model can be loaded, used for prediction, and passes
# run_booster_check(). Requires network access to the artifact bucket.
test_that("Models from previous versions of XGBoost can be loaded", {
  bucket <- "xgboost-ci-jenkins-artifacts"
  region <- "us-west-2"
  file_name <- "xgboost_r_model_compatibility_test.zip"
  zipfile <- tempfile(fileext = ".zip")
  extract_dir <- tempdir()
  download.file(
    paste0("https://", bucket, ".s3-", region, ".amazonaws.com/", file_name),
    destfile = zipfile, mode = "wb", quiet = TRUE
  )
  unzip(zipfile, exdir = extract_dir, overwrite = TRUE)
  model_dir <- file.path(extract_dir, "models")

  # A single all-zero row with 4 columns, used only to exercise predict().
  pred_data <- xgb.DMatrix(
    matrix(c(0, 0, 0, 0), nrow = 1, ncol = 4),
    nthread = 2
  )

  # Plain loop: the per-file work is done for its expectations, not its value.
  for (model_name in list.files(model_dir)) {
    model_file <- file.path(model_dir, model_name)
    # Parse "xgboost-<version>.<name>.<ext>" from the file name alone, so that
    # no component of the temporary directory path can be matched instead.
    m <- regexec("xgboost-([0-9\\.]+)\\.([a-z]+)\\.[a-z]+", model_name, perl = TRUE)
    m <- regmatches(model_name, m)[[1]]
    model_xgb_ver <- m[2]
    name <- m[3]
    is_rds <- endsWith(model_name, ".rds")
    # TODO: update this test for new RDS format
    if (is_rds) {
      next
    }
    # NOTE: while the skip above is in place, `is_rds` is always FALSE from
    # here on, so the RDS branches below are unreachable. They are kept so the
    # RDS checks can be restored by removing the skip.
    # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
    if (is_rds && compareVersion(model_xgb_ver, "1.1.1.1") < 0) {
      booster <- readRDS(model_file)
      expect_warning(predict(booster, newdata = pred_data))
      booster <- readRDS(model_file)
      expect_warning(run_booster_check(booster, name))
    } else {
      if (is_rds) {
        booster <- readRDS(model_file)
      } else {
        booster <- xgb.load(model_file)
        xgb.parameters(booster) <- list(nthread = 2)
      }
      predict(booster, newdata = pred_data)
      run_booster_check(booster, name)
    }
  }
})