[R] Implement feature info for DMatrix. (#8048)
This commit is contained in:
parent
701f32b227
commit
210eb471e9
@ -54,7 +54,10 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||
}
|
||||
dmat <- handle
|
||||
attributes(dmat) <- list(.Dimnames = list(NULL, cnames), class = "xgb.DMatrix")
|
||||
attributes(dmat) <- list(class = "xgb.DMatrix")
|
||||
if (!is.null(cnames)) {
|
||||
setinfo(dmat, "feature_name", cnames)
|
||||
}
|
||||
|
||||
info <- append(info, list(...))
|
||||
for (i in seq_along(info)) {
|
||||
@ -144,7 +147,9 @@ dim.xgb.DMatrix <- function(x) {
|
||||
#' @rdname dimnames.xgb.DMatrix
|
||||
#' @export
|
||||
dimnames.xgb.DMatrix <- function(x) {
  # Feature (column) names are read through getinfo() using the
  # "feature_name" field rather than from a '.Dimnames' R attribute.
  fn <- getinfo(x, "feature_name")
  ## row names is null.
  list(NULL, fn)
}
|
||||
|
||||
#' @rdname dimnames.xgb.DMatrix
|
||||
@ -155,13 +160,13 @@ dimnames.xgb.DMatrix <- function(x) {
|
||||
if (!is.null(value[[1L]]))
|
||||
stop("xgb.DMatrix does not have rownames")
|
||||
if (is.null(value[[2]])) {
|
||||
attr(x, '.Dimnames') <- NULL
|
||||
setinfo(x, "feature_name", NULL)
|
||||
return(x)
|
||||
}
|
||||
if (ncol(x) != length(value[[2]]))
|
||||
stop("can't assign ", length(value[[2]]), " colnames to a ",
|
||||
ncol(x), " column xgb.DMatrix")
|
||||
attr(x, '.Dimnames') <- value
|
||||
if (ncol(x) != length(value[[2]])) {
|
||||
stop("can't assign ", length(value[[2]]), " colnames to a ", ncol(x), " column xgb.DMatrix")
|
||||
}
|
||||
setinfo(x, "feature_name", value[[2]])
|
||||
x
|
||||
}
|
||||
|
||||
@ -203,13 +208,17 @@ getinfo <- function(object, ...) UseMethod("getinfo")
|
||||
#' @export
|
||||
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
if (typeof(name) != "character" ||
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound')) {
|
||||
stop("getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
|
||||
stop(
|
||||
"getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
|
||||
)
|
||||
}
|
||||
if (name != "nrow"){
|
||||
if (name == "feature_name" || name == "feature_type") {
|
||||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||
} else if (name != "nrow"){
|
||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||
} else {
|
||||
ret <- nrow(object)
|
||||
@ -294,6 +303,30 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
return(TRUE)
|
||||
}
|
||||
|
||||
set_feat_info <- function(name) {
  # Helper closure: `info` and `object` are taken from the enclosing scope.
  # Validates that `info` (when not NULL) has one entry per column of
  # `object`, then forwards it as a list to XGDMatrixSetStrFeatureInfo_R
  # under the field `name`. A NULL `info` is forwarded unchanged.
  n_features <- ncol(object)
  mismatch_msg <- sprintf(
    "The number of %s must equal to the number of columns in the input data. %s vs. %s",
    name,
    length(info),
    n_features
  )
  if (is.null(info)) {
    feat_info <- NULL
  } else {
    feat_info <- as.list(info)
    if (length(feat_info) != n_features) {
      stop(mismatch_msg)
    }
  }
  .Call(XGDMatrixSetStrFeatureInfo_R, object, name, feat_info)
}
|
||||
if (name == "feature_name") {
|
||||
set_feat_info("feature_name")
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "feature_type") {
|
||||
set_feat_info("feature_type")
|
||||
return(TRUE)
|
||||
}
|
||||
stop("setinfo: unknown info name ", name)
|
||||
return(FALSE)
|
||||
}
|
||||
|
||||
@ -42,10 +42,12 @@ extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
||||
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
||||
extern SEXP XGBGetGlobalConfig_R();
|
||||
@ -78,10 +80,12 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
||||
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
||||
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
||||
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
|
||||
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
||||
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
||||
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
||||
|
||||
@ -249,15 +249,53 @@ XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
/*
 * Set a string-valued feature info field on a DMatrix.
 *
 * handle: external pointer wrapping the DMatrix handle.
 * field:  name of the field to set, converted with asChar().
 * array:  R NULL, or a generic vector whose elements are each converted to a
 *         C string with asChar(); NULL is passed on as zero entries.
 *
 * Always returns R_NilValue; failures of the underlying call go through
 * CHECK_CALL.
 */
XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
  R_API_BEGIN();
  size_t const n_entries = isNull(array) ? 0 : static_cast<size_t>(length(array));
  char const *field_name = CHAR(asChar(field));

  // Copy every element into an owned std::string first, so the pointers
  // collected below stay valid for the duration of the C API call.
  std::vector<std::string> owned;
  for (size_t idx = 0; idx < n_entries; ++idx) {
    owned.emplace_back(CHAR(asChar(VECTOR_ELT(array, idx))));
  }

  // Pointers are taken only after `owned` has stopped growing.
  std::vector<char const *> c_strs(n_entries);
  for (size_t idx = 0; idx < n_entries; ++idx) {
    c_strs[idx] = owned[idx].c_str();
  }

  CHECK_CALL(XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), field_name, c_strs.data(),
                                        n_entries));
  R_API_END();
  return R_NilValue;
}
|
||||
|
||||
/*
 * Fetch a string-valued feature info field from a DMatrix.
 *
 * handle: external pointer wrapping the DMatrix handle.
 * field:  name of the field to read, converted with asChar().
 *
 * Returns a character vector with one element per stored string, or
 * R_NilValue when the field holds no entries.
 */
XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
  SEXP ret;
  R_API_BEGIN();
  char const **out_features{nullptr};
  bst_ulong len{0};
  const char *name = CHAR(asChar(field));
  // Check the status code like the other C API calls in this file do, so a
  // failed lookup is reported instead of being returned as an empty result.
  CHECK_CALL(XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(handle), name, &len, &out_features));

  if (len > 0) {
    ret = PROTECT(allocVector(STRSXP, len));
    for (size_t i = 0; i < len; ++i) {
      SET_STRING_ELT(ret, i, mkChar(out_features[i]));
    }
  } else {
    // Protect in this branch too, so the single UNPROTECT(1) below is
    // balanced on both paths.
    ret = PROTECT(R_NilValue);
  }
  R_API_END();
  UNPROTECT(1);
  return ret;
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
const float *res;
|
||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(field)),
|
||||
&olen,
|
||||
&res));
|
||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||
ret = PROTECT(allocVector(REALSXP, olen));
|
||||
for (size_t i = 0; i < olen; ++i) {
|
||||
REAL(ret)[i] = res[i];
|
||||
|
||||
@ -42,6 +42,20 @@ test_that("xgb.DMatrix: saving, loading", {
|
||||
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
|
||||
expect_equal(dim(dtest4), c(3, 4))
|
||||
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
|
||||
|
||||
# check that feature info is saved
|
||||
data(agaricus.train, package = 'xgboost')
|
||||
dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label)
|
||||
cnames <- colnames(dtrain)
|
||||
expect_equal(length(cnames), 126)
|
||||
tmp_file <- tempfile('xgb.DMatrix_')
|
||||
xgb.DMatrix.save(dtrain, tmp_file)
|
||||
dtrain <- xgb.DMatrix(tmp_file)
|
||||
expect_equal(colnames(dtrain), cnames)
|
||||
|
||||
ft <- rep(c("c", "q"), each=length(cnames)/2)
|
||||
setinfo(dtrain, "feature_type", ft)
|
||||
expect_equal(ft, getinfo(dtrain, "feature_type"))
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: getinfo & setinfo", {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user