[R] Add missing DMatrix functions (#9929)
* `XGDMatrixGetQuantileCut` * `XGDMatrixNumNonMissing` * `XGDMatrixGetDataAsCSR` --------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -375,3 +375,62 @@ test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
|
||||
)
|
||||
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: number of non-missing matches data", {
|
||||
x <- matrix(1:10, nrow = 5)
|
||||
dm1 <- xgb.DMatrix(x)
|
||||
expect_equal(xgb.get.DMatrix.num.non.missing(dm1), 10)
|
||||
|
||||
x[2, 2] <- NA
|
||||
x[4, 1] <- NA
|
||||
dm2 <- xgb.DMatrix(x)
|
||||
expect_equal(xgb.get.DMatrix.num.non.missing(dm2), 8)
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: retrieving data as CSR", {
|
||||
data(mtcars)
|
||||
dm <- xgb.DMatrix(as.matrix(mtcars))
|
||||
csr <- xgb.get.DMatrix.data(dm)
|
||||
expect_equal(dim(csr), dim(mtcars))
|
||||
expect_equal(colnames(csr), colnames(mtcars))
|
||||
expect_equal(unname(as.matrix(csr)), unname(as.matrix(mtcars)), tolerance = 1e-6)
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: quantile cuts look correct", {
|
||||
data(mtcars)
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(x, label = y)
|
||||
model <- xgb.train(
|
||||
data = dm,
|
||||
params = list(
|
||||
tree_method = "hist",
|
||||
max_bin = 8,
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 3
|
||||
)
|
||||
qcut_list <- xgb.get.DMatrix.qcut(dm, "list")
|
||||
qcut_arrays <- xgb.get.DMatrix.qcut(dm, "arrays")
|
||||
|
||||
expect_equal(length(qcut_arrays), 2)
|
||||
expect_equal(names(qcut_arrays), c("indptr", "data"))
|
||||
expect_equal(length(qcut_arrays$indptr), ncol(x) + 1)
|
||||
expect_true(min(diff(qcut_arrays$indptr)) > 0)
|
||||
|
||||
col_min <- apply(x, 2, min)
|
||||
col_max <- apply(x, 2, max)
|
||||
|
||||
expect_equal(length(qcut_list), ncol(x))
|
||||
expect_equal(names(qcut_list), colnames(x))
|
||||
lapply(
|
||||
seq(1, ncol(x)),
|
||||
function(col) {
|
||||
cuts <- qcut_list[[col]]
|
||||
expect_true(min(diff(cuts)) > 0)
|
||||
expect_true(col_min[col] > cuts[1])
|
||||
expect_true(col_max[col] < cuts[length(cuts)])
|
||||
expect_true(length(cuts) <= 9)
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user