[R] Add xgb.DMatrix.hasinfo() and track which info fields were assigned

Summary of what the hunks below do:
  * xgb.DMatrix() attaches a "fields" environment attribute to the object, and
    setinfo.xgb.DMatrix() marks a field name as TRUE in it after the internal
    setter has run.
  * New exported xgb.DMatrix.hasinfo(object, info) reads that record. It
    requires 'info' to be a single character string and documents its return
    value (roxygen @return and the matching \value section in the Rd file).
  * print.xgb.DMatrix() and xgb.cv() check XGCheckNullPtr_R before using the
    object, and use xgb.DMatrix.hasinfo() where they previously used getinfo().
  * For the fields touched below, the internal setter passes 'info' to
    XGDMatrixSetInfo_R without as.numeric()/as.integer() and uses NROW() for
    the length checks.

Post-image blob ids on the "index" lines were not recomputed after adding the
@return / 'info' validation lines.

diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE
index bbb5ee225..40ede23a5 100644
--- a/R-package/NAMESPACE
+++ b/R-package/NAMESPACE
@@ -28,6 +28,7 @@ export(setinfo)
 export(slice)
 export(xgb.Booster.complete)
 export(xgb.DMatrix)
+export(xgb.DMatrix.hasinfo)
 export(xgb.DMatrix.save)
 export(xgb.attr)
 export(xgb.attributes)
diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R
index 4b2bb0d2a..11d1105e6 100644
--- a/R-package/R/xgb.DMatrix.R
+++ b/R-package/R/xgb.DMatrix.R
@@ -163,7 +163,10 @@ xgb.DMatrix <- function(
   }
 
   dmat <- handle
-  attributes(dmat) <- list(class = "xgb.DMatrix")
+  attributes(dmat) <- list(
+    class = "xgb.DMatrix",
+    fields = new.env()
+  )
 
   if (!is.null(label)) {
     setinfo(dmat, "label", label)
@@ -199,6 +202,39 @@ xgb.DMatrix <- function(
   return(dmat)
 }
 
+#' @title Check whether DMatrix object has a field
+#' @description Checks whether an xgb.DMatrix object has a given field assigned to
+#' it, such as weights, labels, etc.
+#' @param object The DMatrix object to check for the given \code{info} field.
+#' @param info The field to check for presence or absence in \code{object}.
+#' @return \code{TRUE} if the field has been set on \code{object}, \code{FALSE} otherwise.
+#' @seealso \link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
+#' @examples
+#' library(xgboost)
+#' x <- matrix(1:10, nrow = 5)
+#' dm <- xgb.DMatrix(x, nthread = 1)
+#'
+#' # 'dm' so far doesn't have any fields set
+#' xgb.DMatrix.hasinfo(dm, "label")
+#'
+#' # Fields can be added after construction
+#' setinfo(dm, "label", 1:5)
+#' xgb.DMatrix.hasinfo(dm, "label")
+#' @export
+xgb.DMatrix.hasinfo <- function(object, info) {
+  if (!inherits(object, "xgb.DMatrix")) {
+    stop("Object is not an 'xgb.DMatrix'.")
+  }
+  if (!is.character(info) || length(info) != 1L) {
+    stop("'info' must be a single character string naming the field.")
+  }
+  if (.Call(XGCheckNullPtr_R, object)) {
+    warning("xgb.DMatrix object is invalid. Must be constructed again.")
+    return(FALSE)
+  }
+  return(NVL(attr(object, "fields")[[info]], FALSE))
+}
+
 
 # get dmatrix from data, label
 # internal helper method
@@ -389,7 +425,7 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
 #' @param object Object of class "xgb.DMatrix"
 #' @param name the name of the field to get
 #' @param info the specific field of information to set
-#' @param ... other parameters
+#' @param ... Not used.
 #'
 #' @details
 #' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
@@ -418,6 +454,12 @@ setinfo <- function(object, ...) UseMethod("setinfo")
 #' @rdname setinfo
 #' @export
 setinfo.xgb.DMatrix <- function(object, name, info, ...) {
+  .internal.setinfo.xgb.DMatrix(object, name, info, ...)
+  attr(object, "fields")[[name]] <- TRUE
+  return(TRUE)
+}
+
+.internal.setinfo.xgb.DMatrix <- function(object, name, info, ...) {
   if (name == "label") {
     if (NROW(info) != nrow(object))
       stop("The length of labels must equal to the number of rows in the input data")
@@ -425,19 +467,19 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
     return(TRUE)
   }
   if (name == "label_lower_bound") {
-    if (length(info) != nrow(object))
+    if (NROW(info) != nrow(object))
       stop("The length of lower-bound labels must equal to the number of rows in the input data")
-    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
+    .Call(XGDMatrixSetInfo_R, object, name, info)
     return(TRUE)
   }
   if (name == "label_upper_bound") {
-    if (length(info) != nrow(object))
+    if (NROW(info) != nrow(object))
       stop("The length of upper-bound labels must equal to the number of rows in the input data")
-    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
+    .Call(XGDMatrixSetInfo_R, object, name, info)
     return(TRUE)
   }
   if (name == "weight") {
-    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
+    .Call(XGDMatrixSetInfo_R, object, name, info)
     return(TRUE)
   }
   if (name == "base_margin") {
@@ -447,20 +489,20 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
   if (name == "group") {
     if (sum(info) != nrow(object))
       stop("The sum of groups must equal to the number of rows in the input data")
-    .Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
+    .Call(XGDMatrixSetInfo_R, object, name, info)
     return(TRUE)
   }
   if (name == "qid") {
     if (NROW(info) != nrow(object))
       stop("The length of qid assignments must equal to the number of rows in the input data")
-    .Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
+    .Call(XGDMatrixSetInfo_R, object, name, info)
     return(TRUE)
   }
   if (name == "feature_weights") {
-    if (length(info) != ncol(object)) {
+    if (NROW(info) != ncol(object)) {
       stop("The number of feature weights must equal to the number of columns in the input data")
     }
-    .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
+    .Call(XGDMatrixSetInfo_R, object, name, info)
     return(TRUE)
   }
 
@@ -568,11 +610,15 @@ slice.xgb.DMatrix <- function(object, idxset, ...) {
 #' @method print xgb.DMatrix
 #' @export
 print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
+  if (.Call(XGCheckNullPtr_R, x)) {
+    cat("INVALID xgb.DMatrix object. Must be constructed anew.\n")
+    return(invisible(x))
+  }
   cat('xgb.DMatrix dim:', nrow(x), 'x', ncol(x), ' info: ')
   infos <- character(0)
-  if (length(getinfo(x, 'label')) > 0) infos <- 'label'
-  if (length(getinfo(x, 'weight')) > 0) infos <- c(infos, 'weight')
-  if (length(getinfo(x, 'base_margin')) > 0) infos <- c(infos, 'base_margin')
+  if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
+  if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
+  if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
   if (length(infos) == 0) infos <- 'NA'
   cat(infos)
   cnames <- colnames(x)
diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R
index 9e1ffeddc..1c17d86f0 100644
--- a/R-package/R/xgb.cv.R
+++ b/R-package/R/xgb.cv.R
@@ -126,6 +126,9 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
                    early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
 
   check.deprecation(...)
+  if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
+    stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
+  }
 
   params <- check.booster.params(params, ...)
   # TODO: should we deprecate the redundant 'metrics' parameter?
@@ -136,7 +139,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   check.custom.eval()
 
   # Check the labels
-  if ((inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
+  if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
       (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
     stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
   } else if (inherits(data, 'xgb.DMatrix')) {
diff --git a/R-package/man/setinfo.Rd b/R-package/man/setinfo.Rd
index a8bc56b02..299e72675 100644
--- a/R-package/man/setinfo.Rd
+++ b/R-package/man/setinfo.Rd
@@ -12,7 +12,7 @@ setinfo(object, ...)
 \arguments{
 \item{object}{Object of class "xgb.DMatrix"}
 
-\item{...}{other parameters}
+\item{...}{Not used.}
 
 \item{name}{the name of the field to get}
 
diff --git a/R-package/man/xgb.DMatrix.hasinfo.Rd b/R-package/man/xgb.DMatrix.hasinfo.Rd
new file mode 100644
index 000000000..308d9b42e
--- /dev/null
+++ b/R-package/man/xgb.DMatrix.hasinfo.Rd
@@ -0,0 +1,35 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/xgb.DMatrix.R
+\name{xgb.DMatrix.hasinfo}
+\alias{xgb.DMatrix.hasinfo}
+\title{Check whether DMatrix object has a field}
+\usage{
+xgb.DMatrix.hasinfo(object, info)
+}
+\arguments{
+\item{object}{The DMatrix object to check for the given \code{info} field.}
+
+\item{info}{The field to check for presence or absence in \code{object}.}
+}
+\value{
+\code{TRUE} if the field has been set on \code{object}, \code{FALSE} otherwise.
+}
+\description{
+Checks whether an xgb.DMatrix object has a given field assigned to
+it, such as weights, labels, etc.
+}
+\examples{
+library(xgboost)
+x <- matrix(1:10, nrow = 5)
+dm <- xgb.DMatrix(x, nthread = 1)
+
+# 'dm' so far doesn't have any fields set
+xgb.DMatrix.hasinfo(dm, "label")
+
+# Fields can be added after construction
+setinfo(dm, "label", 1:5)
+xgb.DMatrix.hasinfo(dm, "label")
+}
+\seealso{
+\link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
+}