[R] Use inplace predict (#9829)
--------- Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -139,8 +139,8 @@ test_that("dart prediction works", {
|
||||
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, iterationrange = c(1, nrounds))
|
||||
pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
|
||||
|
||||
expect_true(all(matrix(pred_by_train_0, byrow = TRUE) == matrix(pred_by_xgboost_0, byrow = TRUE)))
|
||||
expect_true(all(matrix(pred_by_train_1, byrow = TRUE) == matrix(pred_by_xgboost_1, byrow = TRUE)))
|
||||
expect_equal(pred_by_train_0, pred_by_xgboost_0, tolerance = 1e-6)
|
||||
expect_equal(pred_by_train_1, pred_by_xgboost_1, tolerance = 1e-6)
|
||||
expect_true(all(matrix(pred_by_train_2, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
|
||||
})
|
||||
|
||||
@@ -651,6 +651,51 @@ test_that("Can use ranking objectives with either 'qid' or 'group'", {
|
||||
expect_equal(pred_qid, pred_gr)
|
||||
})
|
||||
|
||||
test_that("Can predict on data.frame objects", {
|
||||
data("mtcars")
|
||||
y <- mtcars$mpg
|
||||
x_df <- mtcars[, -1]
|
||||
x_mat <- as.matrix(x_df)
|
||||
dm <- xgb.DMatrix(x_mat, label = y, nthread = n_threads)
|
||||
model <- xgb.train(
|
||||
params = list(
|
||||
tree_method = "hist",
|
||||
objective = "reg:squarederror",
|
||||
nthread = n_threads
|
||||
),
|
||||
data = dm,
|
||||
nrounds = 5
|
||||
)
|
||||
|
||||
pred_mat <- predict(model, xgb.DMatrix(x_mat), nthread = n_threads)
|
||||
pred_df <- predict(model, x_df, nthread = n_threads)
|
||||
expect_equal(pred_mat, pred_df)
|
||||
})
|
||||
|
||||
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
|
||||
data("mtcars")
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = n_threads)
|
||||
model <- xgb.train(
|
||||
params = list(
|
||||
tree_method = "hist",
|
||||
objective = "reg:squarederror",
|
||||
nthread = n_threads
|
||||
),
|
||||
data = dm,
|
||||
nrounds = 5
|
||||
)
|
||||
|
||||
set.seed(123)
|
||||
base_margin <- rnorm(nrow(x))
|
||||
dm_w_base <- xgb.DMatrix(data = x, base_margin = base_margin)
|
||||
pred_from_dm <- predict(model, dm_w_base)
|
||||
pred_from_mat <- predict(model, x, base_margin = base_margin)
|
||||
|
||||
expect_equal(pred_from_dm, pred_from_mat)
|
||||
})
|
||||
|
||||
test_that("Coefficients from gblinear have the expected shape and names", {
|
||||
# Single-column coefficients
|
||||
data(mtcars)
|
||||
|
||||
@@ -302,6 +302,37 @@ test_that("xgb.DMatrix: Inf as missing", {
|
||||
file.remove(fname_nan)
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: missing in CSR", {
|
||||
x_dense <- matrix(as.numeric(1:10), nrow = 5)
|
||||
x_dense[2, 1] <- NA_real_
|
||||
|
||||
x_csr <- as(x_dense, "RsparseMatrix")
|
||||
|
||||
m_dense <- xgb.DMatrix(x_dense, nthread = n_threads, missing = NA_real_)
|
||||
xgb.DMatrix.save(m_dense, "dense.dmatrix")
|
||||
|
||||
m_csr <- xgb.DMatrix(x_csr, nthread = n_threads, missing = NA)
|
||||
xgb.DMatrix.save(m_csr, "csr.dmatrix")
|
||||
|
||||
denseconn <- file("dense.dmatrix", "rb")
|
||||
csrconn <- file("csr.dmatrix", "rb")
|
||||
|
||||
expect_equal(file.size("dense.dmatrix"), file.size("csr.dmatrix"))
|
||||
|
||||
bytes <- file.size("dense.dmatrix")
|
||||
densedmatrix <- readBin(denseconn, "raw", n = bytes)
|
||||
csrmatrix <- readBin(csrconn, "raw", n = bytes)
|
||||
|
||||
expect_equal(length(densedmatrix), length(csrmatrix))
|
||||
expect_equal(densedmatrix, csrmatrix)
|
||||
|
||||
close(denseconn)
|
||||
close(csrconn)
|
||||
|
||||
file.remove("dense.dmatrix")
|
||||
file.remove("csr.dmatrix")
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: error on three-dimensional array", {
|
||||
set.seed(123)
|
||||
x <- matrix(rnorm(500), nrow = 50)
|
||||
|
||||
Reference in New Issue
Block a user