[R] fix CSR input. (#8673)

This commit is contained in:
Jiaming Yuan 2023-01-14 01:32:41 +08:00 committed by GitHub
parent 72ec0c5484
commit b2b6a8aa39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 5 deletions

View File

@ -205,11 +205,11 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
if (DMLC_LITTLE_ENDIAN) {
jindptr["typestr"] = String{"<i4"};
jindices["typestr"] = String{"<i4"};
jdata["typestr"] = String{"<i8"};
jdata["typestr"] = String{"<f8"};
} else {
jindptr["typestr"] = String{">i4"};
jindices["typestr"] = String{">i4"};
jdata["typestr"] = String{">i8"};
jdata["typestr"] = String{">f8"};
}
std::string indptr, indices, data;
Json::Dump(jindptr, &indptr);

View File

@ -1,6 +1,7 @@
library(Matrix)
context("testing xgb.DMatrix functionality")
data(agaricus.test, package = 'xgboost')
data(agaricus.test, package = "xgboost")
test_data <- agaricus.test$data[1:100, ]
test_label <- agaricus.test$label[1:100]
@ -10,14 +11,49 @@ test_that("xgb.DMatrix: basic construction", {
# from dense matrix
dtest2 <- xgb.DMatrix(as.matrix(test_data), label = test_label)
expect_equal(getinfo(dtest1, 'label'), getinfo(dtest2, 'label'))
expect_equal(getinfo(dtest1, "label"), getinfo(dtest2, "label"))
expect_equal(dim(dtest1), dim(dtest2))
#from dense integer matrix
# from dense integer matrix
int_data <- as.matrix(test_data)
storage.mode(int_data) <- "integer"
dtest3 <- xgb.DMatrix(int_data, label = test_label)
expect_equal(dim(dtest1), dim(dtest3))
n_samples <- 100
X <- cbind(
x1 = rnorm(n_samples),
x2 = rnorm(n_samples),
x3 = rnorm(n_samples)
)
X <- matrix(X, nrow = n_samples)
y <- rbinom(n = n_samples, size = 1, prob = 1 / 2)
fd <- xgb.DMatrix(X, label = y)
dgc <- as(X, "dgCMatrix")
fdgc <- xgb.DMatrix(dgc, label = y)
dgr <- as(X, "dgRMatrix")
fdgr <- xgb.DMatrix(dgr, label = y)
params <- list(tree_method = "hist")
bst_fd <- xgb.train(
params, nrounds = 8, fd, watchlist = list(train = fd)
)
bst_dgr <- xgb.train(
params, nrounds = 8, fdgr, watchlist = list(train = fdgr)
)
bst_dgc <- xgb.train(
params, nrounds = 8, fdgc, watchlist = list(train = fdgc)
)
raw_fd <- xgb.save.raw(bst_fd, raw_format = "ubj")
raw_dgr <- xgb.save.raw(bst_dgr, raw_format = "ubj")
raw_dgc <- xgb.save.raw(bst_dgc, raw_format = "ubj")
expect_equal(raw_fd, raw_dgr)
expect_equal(raw_fd, raw_dgc)
})
test_that("xgb.DMatrix: saving, loading", {