[R] Error out on multidimensional arrays (#9852)

This commit is contained in:
david-cortes 2023-12-06 10:43:51 +01:00 committed by GitHub
parent 62571b79eb
commit 0716c64ef7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 0 deletions

View File

@ -50,6 +50,9 @@ SEXP SafeMkChar(const char *c_str, SEXP continuation_token) {
[[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) { [[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) {
SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol); SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol);
if (Rf_xlength(mat_dims) > 2) {
LOG(FATAL) << "Passed input array with more than two dimensions, which is not supported.";
}
const int *ptr_mat_dims = INTEGER(mat_dims); const int *ptr_mat_dims = INTEGER(mat_dims);
// Lambda for type dispatch. // Lambda for type dispatch.

View File

@ -297,3 +297,11 @@ test_that("xgb.DMatrix: Inf as missing", {
file.remove("inf.dmatrix") file.remove("inf.dmatrix")
file.remove("nan.dmatrix") file.remove("nan.dmatrix")
}) })
test_that("xgb.DMatrix: error on three-dimensional array", {
set.seed(123)
x <- matrix(rnorm(500), nrow = 50)
y <- rnorm(400)
dim(y) <- c(50, 4, 2)
expect_error(xgb.DMatrix(data = x, label = y))
})