[R] Implement feature info for DMatrix. (#8048)
This commit is contained in:
parent
701f32b227
commit
210eb471e9
@ -54,7 +54,10 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
|||||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||||
}
|
}
|
||||||
dmat <- handle
|
dmat <- handle
|
||||||
attributes(dmat) <- list(.Dimnames = list(NULL, cnames), class = "xgb.DMatrix")
|
attributes(dmat) <- list(class = "xgb.DMatrix")
|
||||||
|
if (!is.null(cnames)) {
|
||||||
|
setinfo(dmat, "feature_name", cnames)
|
||||||
|
}
|
||||||
|
|
||||||
info <- append(info, list(...))
|
info <- append(info, list(...))
|
||||||
for (i in seq_along(info)) {
|
for (i in seq_along(info)) {
|
||||||
@ -144,7 +147,9 @@ dim.xgb.DMatrix <- function(x) {
|
|||||||
#' @rdname dimnames.xgb.DMatrix
|
#' @rdname dimnames.xgb.DMatrix
|
||||||
#' @export
|
#' @export
|
||||||
dimnames.xgb.DMatrix <- function(x) {
|
dimnames.xgb.DMatrix <- function(x) {
|
||||||
attr(x, '.Dimnames')
|
fn <- getinfo(x, "feature_name")
|
||||||
|
## row names is null.
|
||||||
|
list(NULL, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
#' @rdname dimnames.xgb.DMatrix
|
#' @rdname dimnames.xgb.DMatrix
|
||||||
@ -155,13 +160,13 @@ dimnames.xgb.DMatrix <- function(x) {
|
|||||||
if (!is.null(value[[1L]]))
|
if (!is.null(value[[1L]]))
|
||||||
stop("xgb.DMatrix does not have rownames")
|
stop("xgb.DMatrix does not have rownames")
|
||||||
if (is.null(value[[2]])) {
|
if (is.null(value[[2]])) {
|
||||||
attr(x, '.Dimnames') <- NULL
|
setinfo(x, "feature_name", NULL)
|
||||||
return(x)
|
return(x)
|
||||||
}
|
}
|
||||||
if (ncol(x) != length(value[[2]]))
|
if (ncol(x) != length(value[[2]])) {
|
||||||
stop("can't assign ", length(value[[2]]), " colnames to a ",
|
stop("can't assign ", length(value[[2]]), " colnames to a ", ncol(x), " column xgb.DMatrix")
|
||||||
ncol(x), " column xgb.DMatrix")
|
}
|
||||||
attr(x, '.Dimnames') <- value
|
setinfo(x, "feature_name", value[[2]])
|
||||||
x
|
x
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -205,11 +210,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
|
|||||||
if (typeof(name) != "character" ||
|
if (typeof(name) != "character" ||
|
||||||
length(name) != 1 ||
|
length(name) != 1 ||
|
||||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||||
'label_lower_bound', 'label_upper_bound')) {
|
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
|
||||||
stop("getinfo: name must be one of the following\n",
|
stop(
|
||||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
|
"getinfo: name must be one of the following\n",
|
||||||
|
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if (name != "nrow"){
|
if (name == "feature_name" || name == "feature_type") {
|
||||||
|
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||||
|
} else if (name != "nrow"){
|
||||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||||
} else {
|
} else {
|
||||||
ret <- nrow(object)
|
ret <- nrow(object)
|
||||||
@ -294,6 +303,30 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
|||||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
## Store string-valued feature info on the DMatrix under the field `name`.
## `info` and `object` are taken from the enclosing function. A NULL `info`
## is passed through as-is; anything else is converted to a list and must
## have exactly one entry per column of `object`.
set_feat_info <- function(name) {
  n_given <- length(info)
  n_cols <- ncol(object)
  mismatch_msg <- sprintf(
    "The number of %s must equal to the number of columns in the input data. %s vs. %s",
    name,
    n_given,
    n_cols
  )
  if (is.null(info)) {
    feat <- NULL
  } else {
    feat <- as.list(info)
    if (length(feat) != n_cols) {
      stop(mismatch_msg)
    }
  }
  .Call(XGDMatrixSetStrFeatureInfo_R, object, name, feat)
}
|
||||||
|
if (name == "feature_name") {
|
||||||
|
set_feat_info("feature_name")
|
||||||
|
return(TRUE)
|
||||||
|
}
|
||||||
|
if (name == "feature_type") {
|
||||||
|
set_feat_info("feature_type")
|
||||||
|
return(TRUE)
|
||||||
|
}
|
||||||
stop("setinfo: unknown info name ", name)
|
stop("setinfo: unknown info name ", name)
|
||||||
return(FALSE)
|
return(FALSE)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,10 +42,12 @@ extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
|||||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
||||||
|
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||||
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
||||||
|
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
||||||
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
||||||
extern SEXP XGBGetGlobalConfig_R();
|
extern SEXP XGBGetGlobalConfig_R();
|
||||||
@ -78,10 +80,12 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
||||||
|
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
||||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||||
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
||||||
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
||||||
|
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
|
||||||
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
||||||
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
||||||
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
||||||
|
|||||||
@ -249,15 +249,53 @@ XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
|||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set a string-valued feature info field on a DMatrix.
//   handle: external pointer whose address is passed to the C API
//   field:  name of the field, converted with asChar
//   array:  generic vector whose elements are each converted with asChar;
//           NULL is forwarded as zero entries
// Always returns R_NilValue; a non-zero status is handled by CHECK_CALL.
XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
  R_API_BEGIN();
  size_t const n_entries = isNull(array) ? 0 : static_cast<size_t>(length(array));
  char const *field_name = CHAR(asChar(field));

  // Keep owned copies of every string so the raw pointers collected below
  // remain valid for the duration of the C API call.
  std::vector<std::string> owned;
  owned.reserve(n_entries);
  for (size_t idx = 0; idx < n_entries; ++idx) {
    owned.emplace_back(CHAR(asChar(VECTOR_ELT(array, idx))));
  }

  // Collect the pointers only after `owned` is fully built: growing the
  // vector could otherwise relocate the strings and invalidate them.
  std::vector<char const *> c_strs;
  c_strs.reserve(n_entries);
  for (auto const &entry : owned) {
    c_strs.push_back(entry.c_str());
  }

  CHECK_CALL(XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), field_name, c_strs.data(),
                                        n_entries));
  R_API_END();
  return R_NilValue;
}
|
||||||
|
|
||||||
|
// Read a string-valued feature info field from a DMatrix.
//   handle: external pointer whose address is passed to the C API
//   field:  name of the field, converted with asChar
// Returns a character vector with one element per string reported by the
// C API, or NULL when the API reports zero entries. A non-zero status from
// the C API is handled by CHECK_CALL instead of being silently ignored.
XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
  SEXP ret;
  R_API_BEGIN();
  char const **out_features{nullptr};
  bst_ulong len{0};
  const char *name = CHAR(asChar(field));
  // Check the status code, as XGDMatrixGetInfo_R does for its C API call.
  // Without the check a failed lookup leaves len == 0 and the caller
  // receives NULL, indistinguishable from "no info stored".
  CHECK_CALL(XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(handle), name, &len, &out_features));

  if (len > 0) {
    ret = PROTECT(allocVector(STRSXP, len));
    for (size_t i = 0; i < len; ++i) {
      SET_STRING_ELT(ret, i, mkChar(out_features[i]));
    }
  } else {
    // Protect R_NilValue as well so the single UNPROTECT(1) below is
    // balanced on both branches.
    ret = PROTECT(R_NilValue);
  }
  R_API_END();
  UNPROTECT(1);
  return ret;
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||||
SEXP ret;
|
SEXP ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
bst_ulong olen;
|
bst_ulong olen;
|
||||||
const float *res;
|
const float *res;
|
||||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
|
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||||
CHAR(asChar(field)),
|
|
||||||
&olen,
|
|
||||||
&res));
|
|
||||||
ret = PROTECT(allocVector(REALSXP, olen));
|
ret = PROTECT(allocVector(REALSXP, olen));
|
||||||
for (size_t i = 0; i < olen; ++i) {
|
for (size_t i = 0; i < olen; ++i) {
|
||||||
REAL(ret)[i] = res[i];
|
REAL(ret)[i] = res[i];
|
||||||
|
|||||||
@ -42,6 +42,20 @@ test_that("xgb.DMatrix: saving, loading", {
|
|||||||
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
|
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
|
||||||
expect_equal(dim(dtest4), c(3, 4))
|
expect_equal(dim(dtest4), c(3, 4))
|
||||||
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
|
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
|
||||||
|
|
||||||
|
# check that feature info is saved
|
||||||
|
data(agaricus.train, package = 'xgboost')
|
||||||
|
dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label)
|
||||||
|
cnames <- colnames(dtrain)
|
||||||
|
expect_equal(length(cnames), 126)
|
||||||
|
tmp_file <- tempfile('xgb.DMatrix_')
|
||||||
|
xgb.DMatrix.save(dtrain, tmp_file)
|
||||||
|
dtrain <- xgb.DMatrix(tmp_file)
|
||||||
|
expect_equal(colnames(dtrain), cnames)
|
||||||
|
|
||||||
|
ft <- rep(c("c", "q"), each=length(cnames)/2)
|
||||||
|
setinfo(dtrain, "feature_type", ft)
|
||||||
|
expect_equal(ft, getinfo(dtrain, "feature_type"))
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("xgb.DMatrix: getinfo & setinfo", {
|
test_that("xgb.DMatrix: getinfo & setinfo", {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user