[R] Implement feature info for DMatrix. (#8048)

This commit is contained in:
Jiaming Yuan 2022-07-09 05:57:39 +08:00 committed by GitHub
parent 701f32b227
commit 210eb471e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 17 deletions

View File

@ -54,7 +54,10 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
stop("xgb.DMatrix does not support construction from ", typeof(data))
}
dmat <- handle
attributes(dmat) <- list(.Dimnames = list(NULL, cnames), class = "xgb.DMatrix")
attributes(dmat) <- list(class = "xgb.DMatrix")
if (!is.null(cnames)) {
setinfo(dmat, "feature_name", cnames)
}
info <- append(info, list(...))
for (i in seq_along(info)) {
@ -144,7 +147,9 @@ dim.xgb.DMatrix <- function(x) {
#' @rdname dimnames.xgb.DMatrix
#' @export
dimnames.xgb.DMatrix <- function(x) {
# Returns the dimnames of an xgb.DMatrix as a two-element list:
# list(<row names>, <column names>).
# NOTE(review): this listing looks like a rendered diff without +/- markers;
# the attr() line below appears to be the pre-change body. As written its
# value is discarded and it does not affect the result -- confirm.
attr(x, '.Dimnames')
# Column names are read from the "feature_name" info field via getinfo().
fn <- getinfo(x, "feature_name")
## Row names are always NULL in the returned list.
list(NULL, fn)
}
#' @rdname dimnames.xgb.DMatrix
@ -155,13 +160,13 @@ dimnames.xgb.DMatrix <- function(x) {
if (!is.null(value[[1L]]))
stop("xgb.DMatrix does not have rownames")
if (is.null(value[[2]])) {
attr(x, '.Dimnames') <- NULL
setinfo(x, "feature_name", NULL)
return(x)
}
if (ncol(x) != length(value[[2]]))
stop("can't assign ", length(value[[2]]), " colnames to a ",
ncol(x), " column xgb.DMatrix")
attr(x, '.Dimnames') <- value
if (ncol(x) != length(value[[2]])) {
stop("can't assign ", length(value[[2]]), " colnames to a ", ncol(x), " column xgb.DMatrix")
}
setinfo(x, "feature_name", value[[2]])
x
}
@ -203,13 +208,17 @@ getinfo <- function(object, ...) UseMethod("getinfo")
#' @export
getinfo.xgb.DMatrix <- function(object, name, ...) {
if (typeof(name) != "character" ||
length(name) != 1 ||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
'label_lower_bound', 'label_upper_bound')) {
stop("getinfo: name must be one of the following\n",
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
length(name) != 1 ||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
stop(
"getinfo: name must be one of the following\n",
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
)
}
if (name != "nrow"){
if (name == "feature_name" || name == "feature_type") {
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
} else if (name != "nrow"){
ret <- .Call(XGDMatrixGetInfo_R, object, name)
} else {
ret <- nrow(object)
@ -294,6 +303,30 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
return(TRUE)
}
# Helper that stores per-feature string metadata under the field `name`.
# It reads `info` and `object` from the enclosing scope. A non-NULL `info`
# is converted to a list and must have one entry per column of `object`;
# a NULL `info` is forwarded unchanged to the native routine.
set_feat_info <- function(name) {
  # Both counts are taken up front, before `info` is converted to a list.
  n_given <- length(info)
  n_cols <- ncol(object)
  if (!is.null(info)) {
    info <- as.list(info)
    if (length(info) != n_cols) {
      stop(
        sprintf(
          "The number of %s must equal to the number of columns in the input data. %s vs. %s",
          name,
          n_given,
          n_cols
        )
      )
    }
  }
  .Call(XGDMatrixSetStrFeatureInfo_R, object, name, info)
}
if (name == "feature_name") {
set_feat_info("feature_name")
return(TRUE)
}
if (name == "feature_type") {
set_feat_info("feature_type")
return(TRUE)
}
stop("setinfo: unknown info name ", name)
return(FALSE)
}

View File

@ -42,10 +42,12 @@ extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixNumCol_R(SEXP);
extern SEXP XGDMatrixNumRow_R(SEXP);
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
extern SEXP XGBSetGlobalConfig_R(SEXP);
extern SEXP XGBGetGlobalConfig_R();
@ -78,10 +80,12 @@ static const R_CallMethodDef CallEntries[] = {
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},

View File

@ -249,15 +249,53 @@ XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
return R_NilValue;
}
// R entry point: stores string-valued feature metadata on a DMatrix.
//
//   handle  external pointer wrapping the DMatrix
//   field   R value converted with asChar() to the name of the info field
//   array   R NULL, or a generic vector whose elements are each converted
//           with asChar(); NULL is forwarded as zero entries
//
// Always returns R_NilValue.
XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
  R_API_BEGIN();
  size_t n_entries = 0;
  if (!isNull(array)) {
    n_entries = length(array);
  }
  const char *field_name = CHAR(asChar(field));

  // Keep owned copies of every string so the raw pointers handed to the
  // C API below stay valid for the duration of the call.
  std::vector<std::string> owned;
  owned.reserve(n_entries);
  for (size_t idx = 0; idx < n_entries; ++idx) {
    owned.emplace_back(CHAR(asChar(VECTOR_ELT(array, idx))));
  }

  // Pointers are collected only after `owned` is complete, so no later
  // growth of `owned` can invalidate them.
  std::vector<char const *> c_strs;
  c_strs.reserve(n_entries);
  for (auto const &entry : owned) {
    c_strs.push_back(entry.c_str());
  }

  CHECK_CALL(
      XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), field_name, c_strs.data(), n_entries));
  R_API_END();
  return R_NilValue;
}
// R entry point: reads string-valued feature metadata from a DMatrix.
//
//   handle  external pointer wrapping the DMatrix
//   field   R value converted with asChar() to the name of the info field
//
// Returns a character vector with one element per stored string, or
// R_NilValue when the field holds no entries.
XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
  SEXP ret;
  R_API_BEGIN();
  char const **out_features{nullptr};
  bst_ulong len{0};
  const char *name = CHAR(asChar(field));
  // Check the status code like the sibling wrappers do; previously the
  // result was ignored, so a failed lookup was indistinguishable from an
  // empty field and silently produced NULL.
  CHECK_CALL(XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(handle), name, &len, &out_features));

  if (len > 0) {
    ret = PROTECT(allocVector(STRSXP, len));
    for (bst_ulong i = 0; i < len; ++i) {
      SET_STRING_ELT(ret, i, mkChar(out_features[i]));
    }
  } else {
    // Protect in this branch too so the single UNPROTECT(1) below is balanced.
    ret = PROTECT(R_NilValue);
  }
  R_API_END();
  UNPROTECT(1);
  return ret;
}
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
SEXP ret;
R_API_BEGIN();
bst_ulong olen;
const float *res;
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
&olen,
&res));
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];

View File

@ -42,6 +42,20 @@ test_that("xgb.DMatrix: saving, loading", {
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
# check that feature info is saved
data(agaricus.train, package = 'xgboost')
dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label)
cnames <- colnames(dtrain)
expect_equal(length(cnames), 126)
tmp_file <- tempfile('xgb.DMatrix_')
xgb.DMatrix.save(dtrain, tmp_file)
dtrain <- xgb.DMatrix(tmp_file)
expect_equal(colnames(dtrain), cnames)
ft <- rep(c("c", "q"), each=length(cnames)/2)
setinfo(dtrain, "feature_type", ft)
expect_equal(ft, getinfo(dtrain, "feature_type"))
})
test_that("xgb.DMatrix: getinfo & setinfo", {