From b2b6a8aa39a31fd883efcc916cd0c9aaaa758ba5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 14 Jan 2023 01:32:41 +0800 Subject: [PATCH] [R] fix CSR input. (#8673) --- R-package/src/xgboost_R.cc | 4 +-- R-package/tests/testthat/test_dmatrix.R | 42 +++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 8f775f087..a6732d9ed 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -205,11 +205,11 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP if (DMLC_LITTLE_ENDIAN) { jindptr["typestr"] = String{"i4"}; jindices["typestr"] = String{">i4"}; - jdata["typestr"] = String{">i8"}; + jdata["typestr"] = String{">f8"}; } std::string indptr, indices, data; Json::Dump(jindptr, &indptr); diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 3c922e9a0..da1180f5e 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -1,6 +1,7 @@ +library(Matrix) context("testing xgb.DMatrix functionality") -data(agaricus.test, package = 'xgboost') +data(agaricus.test, package = "xgboost") test_data <- agaricus.test$data[1:100, ] test_label <- agaricus.test$label[1:100] @@ -10,14 +11,49 @@ test_that("xgb.DMatrix: basic construction", { # from dense matrix dtest2 <- xgb.DMatrix(as.matrix(test_data), label = test_label) - expect_equal(getinfo(dtest1, 'label'), getinfo(dtest2, 'label')) + expect_equal(getinfo(dtest1, "label"), getinfo(dtest2, "label")) expect_equal(dim(dtest1), dim(dtest2)) - #from dense integer matrix + # from dense integer matrix int_data <- as.matrix(test_data) storage.mode(int_data) <- "integer" dtest3 <- xgb.DMatrix(int_data, label = test_label) expect_equal(dim(dtest1), dim(dtest3)) + + n_samples <- 100 + X <- cbind( + x1 = rnorm(n_samples), + x2 = rnorm(n_samples), + x3 = rnorm(n_samples) + ) + X <- matrix(X, nrow = n_samples) + y <- rbinom(n = n_samples, size = 1, prob = 1 / 2) + + fd <- xgb.DMatrix(X, label = y) + + dgc <- as(X, "dgCMatrix") + fdgc <- xgb.DMatrix(dgc, label = y) + + dgr <- as(X, "dgRMatrix") + fdgr <- xgb.DMatrix(dgr, label = y) + + params <- list(tree_method = "hist") + bst_fd <- xgb.train( + params, nrounds = 8, fd, watchlist = list(train = fd) + ) + bst_dgr <- xgb.train( + params, nrounds = 8, fdgr, watchlist = list(train = fdgr) + ) + bst_dgc <- xgb.train( + params, nrounds = 8, fdgc, watchlist = list(train = fdgc) + ) + + raw_fd <- xgb.save.raw(bst_fd, raw_format = "ubj") + raw_dgr <- xgb.save.raw(bst_dgr, raw_format = "ubj") + raw_dgc <- xgb.save.raw(bst_dgc, raw_format = "ubj") + + expect_equal(raw_fd, raw_dgr) + expect_equal(raw_fd, raw_dgc) }) test_that("xgb.DMatrix: saving, loading", {