[R] Refactor field logic for dmatrix (#9901)
This commit is contained in:
parent
0edd600f3d
commit
ff3d82c006
@ -28,6 +28,7 @@ export(setinfo)
|
||||
export(slice)
|
||||
export(xgb.Booster.complete)
|
||||
export(xgb.DMatrix)
|
||||
export(xgb.DMatrix.hasinfo)
|
||||
export(xgb.DMatrix.save)
|
||||
export(xgb.attr)
|
||||
export(xgb.attributes)
|
||||
|
||||
@ -163,7 +163,10 @@ xgb.DMatrix <- function(
|
||||
}
|
||||
|
||||
dmat <- handle
|
||||
attributes(dmat) <- list(class = "xgb.DMatrix")
|
||||
attributes(dmat) <- list(
|
||||
class = "xgb.DMatrix",
|
||||
fields = new.env()
|
||||
)
|
||||
|
||||
if (!is.null(label)) {
|
||||
setinfo(dmat, "label", label)
|
||||
@ -199,6 +202,35 @@ xgb.DMatrix <- function(
|
||||
return(dmat)
|
||||
}
|
||||
|
||||
#' @title Check whether DMatrix object has a field
#' @description Reports whether a given information field (for example labels
#' or weights) has been recorded as assigned on an xgb.DMatrix object.
#'
#' The answer is looked up in the \code{fields} attribute of the object, which
#' \code{setinfo} updates whenever a field is assigned.
#' @param object The DMatrix object to check for the given \code{info} field.
#' @param info Name of the field whose presence or absence in \code{object}
#' is to be checked.
#' @return The flag recorded for \code{info}, or \code{FALSE} when nothing has
#' been recorded for it. \code{FALSE} is also returned, together with a
#' warning, when the native null-pointer check reports \code{object} as
#' invalid. An error is raised if \code{object} is not an 'xgb.DMatrix'.
#' @seealso \link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
#' @examples
#' library(xgboost)
#' x <- matrix(1:10, nrow = 5)
#' dm <- xgb.DMatrix(x, nthread = 1)
#'
#' # 'dm' so far doesn't have any fields set
#' xgb.DMatrix.hasinfo(dm, "label")
#'
#' # Fields can be added after construction
#' setinfo(dm, "label", 1:5)
#' xgb.DMatrix.hasinfo(dm, "label")
#' @export
xgb.DMatrix.hasinfo <- function(object, info) {
  # Guard 1: only objects carrying the 'xgb.DMatrix' class are accepted.
  is_dmatrix <- inherits(object, "xgb.DMatrix")
  if (!is_dmatrix) {
    stop("Object is not an 'xgb.DMatrix'.")
  }

  # Guard 2: an object whose native handle fails the null-pointer check cannot
  # be queried; warn and answer FALSE instead of raising an error.
  handle_is_null <- .Call(XGCheckNullPtr_R, object)
  if (handle_is_null) {
    warning("xgb.DMatrix object is invalid. Must be constructed again.")
    return(FALSE)
  }

  # Look the field up in the registry stored on the object; NVL supplies
  # FALSE when there is no entry for 'info'.
  field_registry <- attr(object, "fields")
  NVL(field_registry[[info]], FALSE)
}
|
||||
|
||||
|
||||
# get dmatrix from data, label
|
||||
# internal helper method
|
||||
@ -389,7 +421,7 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
#' @param object Object of class "xgb.DMatrix"
|
||||
#' @param name the name of the field to set
|
||||
#' @param info the specific field of information to set
|
||||
#' @param ... other parameters
|
||||
#' @param ... Not used.
|
||||
#'
|
||||
#' @details
|
||||
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
|
||||
@ -418,6 +450,12 @@ setinfo <- function(object, ...) UseMethod("setinfo")
|
||||
#' @rdname setinfo
|
||||
#' @export
|
||||
setinfo.xgb.DMatrix <- function(object, name, info, ...) {
  # The internal helper performs the actual assignment of 'info' under 'name';
  # any error it raises propagates before the bookkeeping below runs.
  .internal.setinfo.xgb.DMatrix(object, name, info, ...)

  # Flag the field as present in the registry kept in the 'fields' attribute,
  # which is what xgb.DMatrix.hasinfo() consults. Written as explicit
  # get / modify / set steps.
  field_flags <- attr(object, "fields")
  field_flags[[name]] <- TRUE
  attr(object, "fields") <- field_flags

  TRUE
}
|
||||
|
||||
.internal.setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
if (name == "label") {
|
||||
if (NROW(info) != nrow(object))
|
||||
stop("The length of labels must equal to the number of rows in the input data")
|
||||
@ -425,19 +463,19 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "label_lower_bound") {
|
||||
if (length(info) != nrow(object))
|
||||
if (NROW(info) != nrow(object))
|
||||
stop("The length of lower-bound labels must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "label_upper_bound") {
|
||||
if (length(info) != nrow(object))
|
||||
if (NROW(info) != nrow(object))
|
||||
stop("The length of upper-bound labels must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "weight") {
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "base_margin") {
|
||||
@ -447,20 +485,20 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
if (name == "group") {
|
||||
if (sum(info) != nrow(object))
|
||||
stop("The sum of groups must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "qid") {
|
||||
if (NROW(info) != nrow(object))
|
||||
stop("The length of qid assignments must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "feature_weights") {
|
||||
if (length(info) != ncol(object)) {
|
||||
if (NROW(info) != ncol(object)) {
|
||||
stop("The number of feature weights must equal to the number of columns in the input data")
|
||||
}
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
|
||||
@ -568,11 +606,15 @@ slice.xgb.DMatrix <- function(object, idxset, ...) {
|
||||
#' @method print xgb.DMatrix
|
||||
#' @export
|
||||
print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
|
||||
if (.Call(XGCheckNullPtr_R, x)) {
|
||||
cat("INVALID xgb.DMatrix object. Must be constructed anew.\n")
|
||||
return(invisible(x))
|
||||
}
|
||||
cat('xgb.DMatrix dim:', nrow(x), 'x', ncol(x), ' info: ')
|
||||
infos <- character(0)
|
||||
if (length(getinfo(x, 'label')) > 0) infos <- 'label'
|
||||
if (length(getinfo(x, 'weight')) > 0) infos <- c(infos, 'weight')
|
||||
if (length(getinfo(x, 'base_margin')) > 0) infos <- c(infos, 'base_margin')
|
||||
if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
|
||||
if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
|
||||
if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
|
||||
if (length(infos) == 0) infos <- 'NA'
|
||||
cat(infos)
|
||||
cnames <- colnames(x)
|
||||
|
||||
@ -126,6 +126,9 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
|
||||
|
||||
check.deprecation(...)
|
||||
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
|
||||
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
|
||||
}
|
||||
|
||||
params <- check.booster.params(params, ...)
|
||||
# TODO: should we deprecate the redundant 'metrics' parameter?
|
||||
@ -136,7 +139,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
check.custom.eval()
|
||||
|
||||
# Check the labels
|
||||
if ((inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
|
||||
if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
|
||||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
|
||||
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
|
||||
} else if (inherits(data, 'xgb.DMatrix')) {
|
||||
|
||||
@ -12,7 +12,7 @@ setinfo(object, ...)
|
||||
\arguments{
|
||||
\item{object}{Object of class "xgb.DMatrix"}
|
||||
|
||||
\item{...}{other parameters}
|
||||
\item{...}{Not used.}
|
||||
|
||||
\item{name}{the name of the field to get}
|
||||
|
||||
|
||||
32
R-package/man/xgb.DMatrix.hasinfo.Rd
Normal file
32
R-package/man/xgb.DMatrix.hasinfo.Rd
Normal file
@ -0,0 +1,32 @@
|
||||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/xgb.DMatrix.R
|
||||
\name{xgb.DMatrix.hasinfo}
|
||||
\alias{xgb.DMatrix.hasinfo}
|
||||
\title{Check whether DMatrix object has a field}
|
||||
\usage{
|
||||
xgb.DMatrix.hasinfo(object, info)
|
||||
}
|
||||
\arguments{
|
||||
\item{object}{The DMatrix object to check for the given \code{info} field.}
|
||||
|
||||
\item{info}{The field to check for presence or absence in \code{object}.}
|
||||
}
|
||||
\description{
|
||||
Checks whether an xgb.DMatrix object has a given field assigned to
|
||||
it, such as weights, labels, etc.
|
||||
}
|
||||
\examples{
|
||||
library(xgboost)
|
||||
x <- matrix(1:10, nrow = 5)
|
||||
dm <- xgb.DMatrix(x, nthread = 1)
|
||||
|
||||
# 'dm' so far doesn't have any fields set
|
||||
xgb.DMatrix.hasinfo(dm, "label")
|
||||
|
||||
# Fields can be added after construction
|
||||
setinfo(dm, "label", 1:5)
|
||||
xgb.DMatrix.hasinfo(dm, "label")
|
||||
}
|
||||
\seealso{
|
||||
\link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user