[R] fix CSR input. (#8673)
This commit is contained in:
parent
72ec0c5484
commit
b2b6a8aa39
@ -205,11 +205,11 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
||||
if (DMLC_LITTLE_ENDIAN) {
|
||||
jindptr["typestr"] = String{"<i4"};
|
||||
jindices["typestr"] = String{"<i4"};
|
||||
jdata["typestr"] = String{"<i8"};
|
||||
jdata["typestr"] = String{"<f8"};
|
||||
} else {
|
||||
jindptr["typestr"] = String{">i4"};
|
||||
jindices["typestr"] = String{">i4"};
|
||||
jdata["typestr"] = String{">i8"};
|
||||
jdata["typestr"] = String{">f8"};
|
||||
}
|
||||
std::string indptr, indices, data;
|
||||
Json::Dump(jindptr, &indptr);
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
library(Matrix)
|
||||
context("testing xgb.DMatrix functionality")
|
||||
|
||||
data(agaricus.test, package = 'xgboost')
|
||||
data(agaricus.test, package = "xgboost")
|
||||
test_data <- agaricus.test$data[1:100, ]
|
||||
test_label <- agaricus.test$label[1:100]
|
||||
|
||||
@ -10,14 +11,49 @@ test_that("xgb.DMatrix: basic construction", {
|
||||
|
||||
# from dense matrix
|
||||
dtest2 <- xgb.DMatrix(as.matrix(test_data), label = test_label)
|
||||
expect_equal(getinfo(dtest1, 'label'), getinfo(dtest2, 'label'))
|
||||
expect_equal(getinfo(dtest1, "label"), getinfo(dtest2, "label"))
|
||||
expect_equal(dim(dtest1), dim(dtest2))
|
||||
|
||||
#from dense integer matrix
|
||||
# from dense integer matrix
|
||||
int_data <- as.matrix(test_data)
|
||||
storage.mode(int_data) <- "integer"
|
||||
dtest3 <- xgb.DMatrix(int_data, label = test_label)
|
||||
expect_equal(dim(dtest1), dim(dtest3))
|
||||
|
||||
n_samples <- 100
|
||||
X <- cbind(
|
||||
x1 = rnorm(n_samples),
|
||||
x2 = rnorm(n_samples),
|
||||
x3 = rnorm(n_samples)
|
||||
)
|
||||
X <- matrix(X, nrow = n_samples)
|
||||
y <- rbinom(n = n_samples, size = 1, prob = 1 / 2)
|
||||
|
||||
fd <- xgb.DMatrix(X, label = y)
|
||||
|
||||
dgc <- as(X, "dgCMatrix")
|
||||
fdgc <- xgb.DMatrix(dgc, label = y)
|
||||
|
||||
dgr <- as(X, "dgRMatrix")
|
||||
fdgr <- xgb.DMatrix(dgr, label = y)
|
||||
|
||||
params <- list(tree_method = "hist")
|
||||
bst_fd <- xgb.train(
|
||||
params, nrounds = 8, fd, watchlist = list(train = fd)
|
||||
)
|
||||
bst_dgr <- xgb.train(
|
||||
params, nrounds = 8, fdgr, watchlist = list(train = fdgr)
|
||||
)
|
||||
bst_dgc <- xgb.train(
|
||||
params, nrounds = 8, fdgc, watchlist = list(train = fdgc)
|
||||
)
|
||||
|
||||
raw_fd <- xgb.save.raw(bst_fd, raw_format = "ubj")
|
||||
raw_dgr <- xgb.save.raw(bst_dgr, raw_format = "ubj")
|
||||
raw_dgc <- xgb.save.raw(bst_dgc, raw_format = "ubj")
|
||||
|
||||
expect_equal(raw_fd, raw_dgr)
|
||||
expect_equal(raw_fd, raw_dgc)
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: saving, loading", {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user