[R] Refactor field logic for dmatrix (#9901)

This commit is contained in:
david-cortes 2023-12-18 13:31:01 +01:00 committed by GitHub
parent 0edd600f3d
commit ff3d82c006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 16 deletions

View File

@ -28,6 +28,7 @@ export(setinfo)
export(slice)
export(xgb.Booster.complete)
export(xgb.DMatrix)
export(xgb.DMatrix.hasinfo)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.attributes)

View File

@ -163,7 +163,10 @@ xgb.DMatrix <- function(
}
dmat <- handle
attributes(dmat) <- list(class = "xgb.DMatrix")
attributes(dmat) <- list(
class = "xgb.DMatrix",
fields = new.env()
)
if (!is.null(label)) {
setinfo(dmat, "label", label)
@ -199,6 +202,35 @@ xgb.DMatrix <- function(
return(dmat)
}
#' @title Check whether DMatrix object has a field
#' @description Checks whether an xgb.DMatrix object has a given field assigned to
#' it, such as weights, labels, etc.
#' @param object The DMatrix object to check for the given \code{info} field.
#' @param info The field to check for presence or absence in \code{object}.
#' @seealso \link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
#' @examples
#' library(xgboost)
#' x <- matrix(1:10, nrow = 5)
#' dm <- xgb.DMatrix(x, nthread = 1)
#'
#' # 'dm' so far doesn't have any fields set
#' xgb.DMatrix.hasinfo(dm, "label")
#'
#' # Fields can be added after construction
#' setinfo(dm, "label", 1:5)
#' xgb.DMatrix.hasinfo(dm, "label")
#' @export
xgb.DMatrix.hasinfo <- function(object, info) {
if (!inherits(object, "xgb.DMatrix")) {
stop("Object is not an 'xgb.DMatrix'.")
}
if (.Call(XGCheckNullPtr_R, object)) {
warning("xgb.DMatrix object is invalid. Must be constructed again.")
return(FALSE)
}
return(NVL(attr(object, "fields")[[info]], FALSE))
}
# get dmatrix from data, label
# internal helper method
@ -389,7 +421,7 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
#' @param object Object of class "xgb.DMatrix"
#' @param name the name of the field to get
#' @param info the specific field of information to set
#' @param ... other parameters
#' @param ... Not used.
#'
#' @details
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
@ -418,6 +450,12 @@ setinfo <- function(object, ...) UseMethod("setinfo")
#' @rdname setinfo
#' @export
setinfo.xgb.DMatrix <- function(object, name, info, ...) {
.internal.setinfo.xgb.DMatrix(object, name, info, ...)
attr(object, "fields")[[name]] <- TRUE
return(TRUE)
}
.internal.setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "label") {
if (NROW(info) != nrow(object))
stop("The length of labels must equal to the number of rows in the input data")
@ -425,19 +463,19 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
return(TRUE)
}
if (name == "label_lower_bound") {
if (length(info) != nrow(object))
if (NROW(info) != nrow(object))
stop("The length of lower-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "label_upper_bound") {
if (length(info) != nrow(object))
if (NROW(info) != nrow(object))
stop("The length of upper-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "weight") {
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "base_margin") {
@ -447,20 +485,20 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "group") {
if (sum(info) != nrow(object))
stop("The sum of groups must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "qid") {
if (NROW(info) != nrow(object))
stop("The length of qid assignments must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "feature_weights") {
if (length(info) != ncol(object)) {
if (NROW(info) != ncol(object)) {
stop("The number of feature weights must equal to the number of columns in the input data")
}
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
@ -568,11 +606,15 @@ slice.xgb.DMatrix <- function(object, idxset, ...) {
#' @method print xgb.DMatrix
#' @export
print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
if (.Call(XGCheckNullPtr_R, x)) {
cat("INVALID xgb.DMatrix object. Must be constructed anew.\n")
return(invisible(x))
}
cat('xgb.DMatrix dim:', nrow(x), 'x', ncol(x), ' info: ')
infos <- character(0)
if (length(getinfo(x, 'label')) > 0) infos <- 'label'
if (length(getinfo(x, 'weight')) > 0) infos <- c(infos, 'weight')
if (length(getinfo(x, 'base_margin')) > 0) infos <- c(infos, 'base_margin')
if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
if (length(infos) == 0) infos <- 'NA'
cat(infos)
cnames <- colnames(x)

View File

@ -126,6 +126,9 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
check.deprecation(...)
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
}
params <- check.booster.params(params, ...)
# TODO: should we deprecate the redundant 'metrics' parameter?
@ -136,7 +139,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
check.custom.eval()
# Check the labels
if ((inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
} else if (inherits(data, 'xgb.DMatrix')) {

View File

@ -12,7 +12,7 @@ setinfo(object, ...)
\arguments{
\item{object}{Object of class "xgb.DMatrix"}
\item{...}{other parameters}
\item{...}{Not used.}
\item{name}{the name of the field to get}

View File

@ -0,0 +1,32 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.DMatrix.R
\name{xgb.DMatrix.hasinfo}
\alias{xgb.DMatrix.hasinfo}
\title{Check whether DMatrix object has a field}
\usage{
xgb.DMatrix.hasinfo(object, info)
}
\arguments{
\item{object}{The DMatrix object to check for the given \code{info} field.}
\item{info}{The field to check for presence or absence in \code{object}.}
}
\description{
Checks whether an xgb.DMatrix object has a given field assigned to
it, such as weights, labels, etc.
}
\examples{
library(xgboost)
x <- matrix(1:10, nrow = 5)
dm <- xgb.DMatrix(x, nthread = 1)
# 'dm' so far doesn't have any fields set
xgb.DMatrix.hasinfo(dm, "label")
# Fields can be added after construction
setinfo(dm, "label", 1:5)
xgb.DMatrix.hasinfo(dm, "label")
}
\seealso{
\link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
}