[R] Accept CSR data for predictions (#7615)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
require(xgboost)
|
||||
library(Matrix)
|
||||
|
||||
context("basic functions")
|
||||
|
||||
@@ -459,3 +460,18 @@ test_that("strict_shape works", {
|
||||
test_iris()
|
||||
test_agaricus()
|
||||
})
|
||||
|
||||
test_that("'predict' accepts CSR data", {
|
||||
X <- agaricus.train$data
|
||||
y <- agaricus.train$label
|
||||
x_csc <- as(X[1L, , drop = FALSE], "CsparseMatrix")
|
||||
x_csr <- as(x_csc, "RsparseMatrix")
|
||||
x_spv <- as(x_csc, "sparseVector")
|
||||
bst <- xgboost(data = X, label = y, objective = "binary:logistic",
|
||||
nrounds = 5L, verbose = FALSE)
|
||||
p_csc <- predict(bst, x_csc)
|
||||
p_csr <- predict(bst, x_csr)
|
||||
p_spv <- predict(bst, x_spv)
|
||||
expect_equal(p_csc, p_csr)
|
||||
expect_equal(p_csc, p_spv)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user