[R] Accept CSR data for predictions (#7615)

This commit is contained in:
david-cortes
2022-01-29 18:54:57 +02:00
committed by GitHub
parent 549bd419bb
commit 7f738e7f6f
8 changed files with 93 additions and 7 deletions

View File

@@ -1,4 +1,5 @@
require(xgboost)
library(Matrix)
context("basic functions")
@@ -459,3 +460,18 @@ test_that("strict_shape works", {
test_iris()
test_agaricus()
})
test_that("'predict' accepts CSR data", {
X <- agaricus.train$data
y <- agaricus.train$label
x_csc <- as(X[1L, , drop = FALSE], "CsparseMatrix")
x_csr <- as(x_csc, "RsparseMatrix")
x_spv <- as(x_csc, "sparseVector")
bst <- xgboost(data = X, label = y, objective = "binary:logistic",
nrounds = 5L, verbose = FALSE)
p_csc <- predict(bst, x_csc)
p_csr <- predict(bst, x_csr)
p_spv <- predict(bst, x_spv)
expect_equal(p_csc, p_csr)
expect_equal(p_csc, p_spv)
})