[R] Implement feature info for DMatrix. (#8048)
This commit is contained in:
@@ -54,7 +54,10 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||
}
|
||||
dmat <- handle
|
||||
attributes(dmat) <- list(.Dimnames = list(NULL, cnames), class = "xgb.DMatrix")
|
||||
attributes(dmat) <- list(class = "xgb.DMatrix")
|
||||
if (!is.null(cnames)) {
|
||||
setinfo(dmat, "feature_name", cnames)
|
||||
}
|
||||
|
||||
info <- append(info, list(...))
|
||||
for (i in seq_along(info)) {
|
||||
@@ -144,7 +147,9 @@ dim.xgb.DMatrix <- function(x) {
|
||||
#' @rdname dimnames.xgb.DMatrix
|
||||
#' @export
|
||||
dimnames.xgb.DMatrix <- function(x) {
|
||||
attr(x, '.Dimnames')
|
||||
fn <- getinfo(x, "feature_name")
|
||||
## row names is null.
|
||||
list(NULL, fn)
|
||||
}
|
||||
|
||||
#' @rdname dimnames.xgb.DMatrix
|
||||
@@ -155,13 +160,13 @@ dimnames.xgb.DMatrix <- function(x) {
|
||||
if (!is.null(value[[1L]]))
|
||||
stop("xgb.DMatrix does not have rownames")
|
||||
if (is.null(value[[2]])) {
|
||||
attr(x, '.Dimnames') <- NULL
|
||||
setinfo(x, "feature_name", NULL)
|
||||
return(x)
|
||||
}
|
||||
if (ncol(x) != length(value[[2]]))
|
||||
stop("can't assign ", length(value[[2]]), " colnames to a ",
|
||||
ncol(x), " column xgb.DMatrix")
|
||||
attr(x, '.Dimnames') <- value
|
||||
if (ncol(x) != length(value[[2]])) {
|
||||
stop("can't assign ", length(value[[2]]), " colnames to a ", ncol(x), " column xgb.DMatrix")
|
||||
}
|
||||
setinfo(x, "feature_name", value[[2]])
|
||||
x
|
||||
}
|
||||
|
||||
@@ -203,13 +208,17 @@ getinfo <- function(object, ...) UseMethod("getinfo")
|
||||
#' @export
|
||||
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
if (typeof(name) != "character" ||
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound')) {
|
||||
stop("getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
|
||||
stop(
|
||||
"getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
|
||||
)
|
||||
}
|
||||
if (name != "nrow"){
|
||||
if (name == "feature_name" || name == "feature_type") {
|
||||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||
} else if (name != "nrow"){
|
||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||
} else {
|
||||
ret <- nrow(object)
|
||||
@@ -294,6 +303,30 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
return(TRUE)
|
||||
}
|
||||
|
||||
set_feat_info <- function(name) {
|
||||
msg <- sprintf(
|
||||
"The number of %s must equal to the number of columns in the input data. %s vs. %s",
|
||||
name,
|
||||
length(info),
|
||||
ncol(object)
|
||||
)
|
||||
if (!is.null(info)) {
|
||||
info <- as.list(info)
|
||||
if (length(info) != ncol(object)) {
|
||||
stop(msg)
|
||||
}
|
||||
}
|
||||
.Call(XGDMatrixSetStrFeatureInfo_R, object, name, info)
|
||||
}
|
||||
if (name == "feature_name") {
|
||||
set_feat_info("feature_name")
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "feature_type") {
|
||||
set_feat_info("feature_type")
|
||||
return(TRUE)
|
||||
}
|
||||
stop("setinfo: unknown info name ", name)
|
||||
return(FALSE)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user