[R] Refactor field logic for dmatrix (#9901)

This commit is contained in:
david-cortes 2023-12-18 13:31:01 +01:00 committed by GitHub
parent 0edd600f3d
commit ff3d82c006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 16 deletions

View File

@ -28,6 +28,7 @@ export(setinfo)
export(slice) export(slice)
export(xgb.Booster.complete) export(xgb.Booster.complete)
export(xgb.DMatrix) export(xgb.DMatrix)
export(xgb.DMatrix.hasinfo)
export(xgb.DMatrix.save) export(xgb.DMatrix.save)
export(xgb.attr) export(xgb.attr)
export(xgb.attributes) export(xgb.attributes)

View File

@ -163,7 +163,10 @@ xgb.DMatrix <- function(
} }
dmat <- handle dmat <- handle
attributes(dmat) <- list(class = "xgb.DMatrix") attributes(dmat) <- list(
class = "xgb.DMatrix",
fields = new.env()
)
if (!is.null(label)) { if (!is.null(label)) {
setinfo(dmat, "label", label) setinfo(dmat, "label", label)
@ -199,6 +202,35 @@ xgb.DMatrix <- function(
return(dmat) return(dmat)
} }
#' @title Check whether DMatrix object has a field
#' @description Checks whether an xgb.DMatrix object has a given field assigned to
#' it, such as weights, labels, etc.
#' @details Presence is looked up in the bookkeeping environment stored in the
#' \code{"fields"} attribute of \code{object}; the field data itself is not retrieved.
#' @param object The DMatrix object to check for the given \code{info} field.
#' @param info The field to check for presence or absence in \code{object}.
#' Must be a single, non-missing, non-empty character string.
#' @return A single logical value: \code{TRUE} if \code{info} is recorded as set in
#' the \code{"fields"} attribute of \code{object}, \code{FALSE} otherwise. If the
#' underlying handle of \code{object} is no longer valid, a warning is issued and
#' \code{FALSE} is returned.
#' @seealso \link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
#' @examples
#' library(xgboost)
#' x <- matrix(1:10, nrow = 5)
#' dm <- xgb.DMatrix(x, nthread = 1)
#'
#' # 'dm' so far doesn't have any fields set
#' xgb.DMatrix.hasinfo(dm, "label")
#'
#' # Fields can be added after construction
#' setinfo(dm, "label", 1:5)
#' xgb.DMatrix.hasinfo(dm, "label")
#' @export
xgb.DMatrix.hasinfo <- function(object, info) {
  if (!inherits(object, "xgb.DMatrix")) {
    stop("Object is not an 'xgb.DMatrix'.")
  }
  # Reject malformed field names up front: `[[` on an environment would
  # otherwise fail with an error that does not mention 'info'.
  if (!is.character(info) || length(info) != 1L || is.na(info) || !nzchar(info)) {
    stop("'info' must be a single non-empty character string.")
  }
  if (.Call(XGCheckNullPtr_R, object)) {
    warning("xgb.DMatrix object is invalid. Must be constructed again.")
    return(FALSE)
  }
  fields <- attr(object, "fields")
  # An object without the bookkeeping environment is reported as having
  # no fields rather than raising an error.
  if (!is.environment(fields)) {
    return(FALSE)
  }
  # Entries are only ever stored as TRUE; isTRUE() also maps an absent entry
  # (NULL) to FALSE, so the result is always a logical scalar.
  isTRUE(fields[[info]])
}
# get dmatrix from data, label # get dmatrix from data, label
# internal helper method # internal helper method
@ -389,7 +421,7 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
#' @param object Object of class "xgb.DMatrix" #' @param object Object of class "xgb.DMatrix"
#' @param name the name of the field to get #' @param name the name of the field to get
#' @param info the specific field of information to set #' @param info the specific field of information to set
#' @param ... other parameters #' @param ... Not used.
#' #'
#' @details #' @details
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set #' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
@ -418,6 +450,12 @@ setinfo <- function(object, ...) UseMethod("setinfo")
#' @rdname setinfo #' @rdname setinfo
#' @export #' @export
setinfo.xgb.DMatrix <- function(object, name, info, ...) { setinfo.xgb.DMatrix <- function(object, name, info, ...) {
# Hand the per-field checks and the actual assignment over to the internal
# worker; it receives the same arguments unchanged.
.internal.setinfo.xgb.DMatrix(object, name, info, ...)
# Flag `name` as set in the "fields" attribute of the object; this is the
# entry that xgb.DMatrix.hasinfo() looks up.
attr(object, "fields")[[name]] <- TRUE
# Reached only when the worker call above did not raise an error.
return(TRUE)
}
.internal.setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "label") { if (name == "label") {
if (NROW(info) != nrow(object)) if (NROW(info) != nrow(object))
stop("The length of labels must equal to the number of rows in the input data") stop("The length of labels must equal to the number of rows in the input data")
@ -425,19 +463,19 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
return(TRUE) return(TRUE)
} }
if (name == "label_lower_bound") { if (name == "label_lower_bound") {
if (length(info) != nrow(object)) if (NROW(info) != nrow(object))
stop("The length of lower-bound labels must equal to the number of rows in the input data") stop("The length of lower-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
if (name == "label_upper_bound") { if (name == "label_upper_bound") {
if (length(info) != nrow(object)) if (NROW(info) != nrow(object))
stop("The length of upper-bound labels must equal to the number of rows in the input data") stop("The length of upper-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
if (name == "weight") { if (name == "weight") {
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
if (name == "base_margin") { if (name == "base_margin") {
@ -447,20 +485,20 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "group") { if (name == "group") {
if (sum(info) != nrow(object)) if (sum(info) != nrow(object))
stop("The sum of groups must equal to the number of rows in the input data") stop("The sum of groups must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
if (name == "qid") { if (name == "qid") {
if (NROW(info) != nrow(object)) if (NROW(info) != nrow(object))
stop("The length of qid assignments must equal to the number of rows in the input data") stop("The length of qid assignments must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
if (name == "feature_weights") { if (name == "feature_weights") {
if (length(info) != ncol(object)) { if (NROW(info) != ncol(object)) {
stop("The number of feature weights must equal to the number of columns in the input data") stop("The number of feature weights must equal to the number of columns in the input data")
} }
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
@ -568,11 +606,15 @@ slice.xgb.DMatrix <- function(object, idxset, ...) {
#' @method print xgb.DMatrix #' @method print xgb.DMatrix
#' @export #' @export
print.xgb.DMatrix <- function(x, verbose = FALSE, ...) { print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
if (.Call(XGCheckNullPtr_R, x)) {
cat("INVALID xgb.DMatrix object. Must be constructed anew.\n")
return(invisible(x))
}
cat('xgb.DMatrix dim:', nrow(x), 'x', ncol(x), ' info: ') cat('xgb.DMatrix dim:', nrow(x), 'x', ncol(x), ' info: ')
infos <- character(0) infos <- character(0)
if (length(getinfo(x, 'label')) > 0) infos <- 'label' if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
if (length(getinfo(x, 'weight')) > 0) infos <- c(infos, 'weight') if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
if (length(getinfo(x, 'base_margin')) > 0) infos <- c(infos, 'base_margin') if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
if (length(infos) == 0) infos <- 'NA' if (length(infos) == 0) infos <- 'NA'
cat(infos) cat(infos)
cnames <- colnames(x) cnames <- colnames(x)

View File

@ -126,6 +126,9 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) { early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
check.deprecation(...) check.deprecation(...)
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
}
params <- check.booster.params(params, ...) params <- check.booster.params(params, ...)
# TODO: should we deprecate the redundant 'metrics' parameter? # TODO: should we deprecate the redundant 'metrics' parameter?
@ -136,7 +139,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
check.custom.eval() check.custom.eval()
# Check the labels # Check the labels
if ((inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) || if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) { (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix") stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
} else if (inherits(data, 'xgb.DMatrix')) { } else if (inherits(data, 'xgb.DMatrix')) {

View File

@ -12,7 +12,7 @@ setinfo(object, ...)
\arguments{ \arguments{
\item{object}{Object of class "xgb.DMatrix"} \item{object}{Object of class "xgb.DMatrix"}
\item{...}{other parameters} \item{...}{Not used.}
\item{name}{the name of the field to get} \item{name}{the name of the field to get}

View File

@ -0,0 +1,32 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.DMatrix.R
\name{xgb.DMatrix.hasinfo}
\alias{xgb.DMatrix.hasinfo}
\title{Check whether DMatrix object has a field}
\usage{
xgb.DMatrix.hasinfo(object, info)
}
\arguments{
\item{object}{The DMatrix object to check for the given \code{info} field.}
\item{info}{The field to check for presence or absence in \code{object}.}
}
\description{
Checks whether an xgb.DMatrix object has a given field assigned to
it, such as weights, labels, etc.
}
\examples{
library(xgboost)
x <- matrix(1:10, nrow = 5)
dm <- xgb.DMatrix(x, nthread = 1)
# 'dm' so far doesn't have any fields set
xgb.DMatrix.hasinfo(dm, "label")
# Fields can be added after construction
setinfo(dm, "label", 1:5)
xgb.DMatrix.hasinfo(dm, "label")
}
\seealso{
\link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
}