[R] Move all DMatrix fields to function arguments (#9862)
This commit is contained in:
parent
1094d6015d
commit
562352101d
@ -8,13 +8,24 @@
|
|||||||
#' a \code{dgRMatrix} object,
|
#' a \code{dgRMatrix} object,
|
||||||
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
||||||
#' interpreted as a row vector), or a character string representing a filename.
|
#' interpreted as a row vector), or a character string representing a filename.
|
||||||
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
|
#' @param label Label of the training data.
|
||||||
#' See \code{\link{setinfo}} for the specific allowed kinds of
|
#' @param weight Weight for each instance.
|
||||||
|
#'
|
||||||
|
#' Note that, for ranking task, weights are per-group. In ranking task, one weight
|
||||||
|
#' is assigned to each group (not each data point). This is because we
|
||||||
|
#' only care about the relative ordering of data points within each group,
|
||||||
|
#' so it doesn't make sense to assign weights to individual data points.
|
||||||
|
#' @param base_margin Base margin used for boosting from existing model.
|
||||||
#' @param missing a float value to represent missing values in data (used only when input is a dense matrix).
|
#' @param missing a float value to represent missing values in data (used only when input is a dense matrix).
|
||||||
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
||||||
#' @param silent whether to suppress printing an informational message after loading from a file.
|
#' @param silent whether to suppress printing an informational message after loading from a file.
|
||||||
|
#' @param feature_names Set names for features.
|
||||||
#' @param nthread Number of threads used for creating DMatrix.
|
#' @param nthread Number of threads used for creating DMatrix.
|
||||||
#' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list.
|
#' @param group Group sizes for all ranking groups.
|
||||||
|
#' @param qid Query ID for data samples, used for ranking.
|
||||||
|
#' @param label_lower_bound Lower bound for survival training.
|
||||||
|
#' @param label_upper_bound Upper bound for survival training.
|
||||||
|
#' @param feature_weights Set feature weights for column sampling.
|
||||||
#'
|
#'
|
||||||
#' @details
|
#' @details
|
||||||
#' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}.
|
#' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}.
|
||||||
@ -34,8 +45,24 @@
|
|||||||
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
||||||
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
|
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
|
||||||
#' @export
|
#' @export
|
||||||
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
|
xgb.DMatrix <- function(
|
||||||
cnames <- NULL
|
data,
|
||||||
|
label = NULL,
|
||||||
|
weight = NULL,
|
||||||
|
base_margin = NULL,
|
||||||
|
missing = NA,
|
||||||
|
silent = FALSE,
|
||||||
|
feature_names = colnames(data),
|
||||||
|
nthread = NULL,
|
||||||
|
group = NULL,
|
||||||
|
qid = NULL,
|
||||||
|
label_lower_bound = NULL,
|
||||||
|
label_upper_bound = NULL,
|
||||||
|
feature_weights = NULL
|
||||||
|
) {
|
||||||
|
if (!is.null(group) && !is.null(qid)) {
|
||||||
|
stop("Either one of 'group' or 'qid' should be NULL")
|
||||||
|
}
|
||||||
if (typeof(data) == "character") {
|
if (typeof(data) == "character") {
|
||||||
if (length(data) > 1)
|
if (length(data) > 1)
|
||||||
stop("'data' has class 'character' and length ", length(data),
|
stop("'data' has class 'character' and length ", length(data),
|
||||||
@ -44,7 +71,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
|||||||
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
|
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
|
||||||
} else if (is.matrix(data)) {
|
} else if (is.matrix(data)) {
|
||||||
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
|
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
|
||||||
cnames <- colnames(data)
|
|
||||||
} else if (inherits(data, "dgCMatrix")) {
|
} else if (inherits(data, "dgCMatrix")) {
|
||||||
handle <- .Call(
|
handle <- .Call(
|
||||||
XGDMatrixCreateFromCSC_R,
|
XGDMatrixCreateFromCSC_R,
|
||||||
@ -55,7 +81,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
|||||||
missing,
|
missing,
|
||||||
as.integer(NVL(nthread, -1))
|
as.integer(NVL(nthread, -1))
|
||||||
)
|
)
|
||||||
cnames <- colnames(data)
|
|
||||||
} else if (inherits(data, "dgRMatrix")) {
|
} else if (inherits(data, "dgRMatrix")) {
|
||||||
handle <- .Call(
|
handle <- .Call(
|
||||||
XGDMatrixCreateFromCSR_R,
|
XGDMatrixCreateFromCSR_R,
|
||||||
@ -66,7 +91,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
|||||||
missing,
|
missing,
|
||||||
as.integer(NVL(nthread, -1))
|
as.integer(NVL(nthread, -1))
|
||||||
)
|
)
|
||||||
cnames <- colnames(data)
|
|
||||||
} else if (inherits(data, "dsparseVector")) {
|
} else if (inherits(data, "dsparseVector")) {
|
||||||
indptr <- c(0L, as.integer(length(data@i)))
|
indptr <- c(0L, as.integer(length(data@i)))
|
||||||
ind <- as.integer(data@i) - 1L
|
ind <- as.integer(data@i) - 1L
|
||||||
@ -82,17 +106,38 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
|||||||
} else {
|
} else {
|
||||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
dmat <- handle
|
dmat <- handle
|
||||||
attributes(dmat) <- list(class = "xgb.DMatrix")
|
attributes(dmat) <- list(class = "xgb.DMatrix")
|
||||||
if (!is.null(cnames)) {
|
|
||||||
setinfo(dmat, "feature_name", cnames)
|
if (!is.null(label)) {
|
||||||
|
setinfo(dmat, "label", label)
|
||||||
|
}
|
||||||
|
if (!is.null(weight)) {
|
||||||
|
setinfo(dmat, "weight", weight)
|
||||||
|
}
|
||||||
|
if (!is.null(base_margin)) {
|
||||||
|
setinfo(dmat, "base_margin", base_margin)
|
||||||
|
}
|
||||||
|
if (!is.null(feature_names)) {
|
||||||
|
setinfo(dmat, "feature_name", feature_names)
|
||||||
|
}
|
||||||
|
if (!is.null(group)) {
|
||||||
|
setinfo(dmat, "group", group)
|
||||||
|
}
|
||||||
|
if (!is.null(qid)) {
|
||||||
|
setinfo(dmat, "qid", qid)
|
||||||
|
}
|
||||||
|
if (!is.null(label_lower_bound)) {
|
||||||
|
setinfo(dmat, "label_lower_bound", label_lower_bound)
|
||||||
|
}
|
||||||
|
if (!is.null(label_upper_bound)) {
|
||||||
|
setinfo(dmat, "label_upper_bound", label_upper_bound)
|
||||||
|
}
|
||||||
|
if (!is.null(feature_weights)) {
|
||||||
|
setinfo(dmat, "feature_weights", feature_weights)
|
||||||
}
|
}
|
||||||
|
|
||||||
info <- append(info, list(...))
|
|
||||||
for (i in seq_along(info)) {
|
|
||||||
p <- info[i]
|
|
||||||
setinfo(dmat, names(p), p[[1]])
|
|
||||||
}
|
|
||||||
return(dmat)
|
return(dmat)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,14 +256,20 @@ dimnames.xgb.DMatrix <- function(x) {
|
|||||||
#' The \code{name} field can be one of the following:
|
#' The \code{name} field can be one of the following:
|
||||||
#'
|
#'
|
||||||
#' \itemize{
|
#' \itemize{
|
||||||
#' \item \code{label}: label XGBoost learn from ;
|
#' \item \code{label}
|
||||||
#' \item \code{weight}: to do a weight rescale ;
|
#' \item \code{weight}
|
||||||
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
#' \item \code{base_margin}
|
||||||
#' \item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
|
#' \item \code{label_lower_bound}
|
||||||
#'
|
#' \item \code{label_upper_bound}
|
||||||
|
#' \item \code{group}
|
||||||
|
#' \item \code{feature_type}
|
||||||
|
#' \item \code{feature_name}
|
||||||
|
#' \item \code{nrow}
|
||||||
#' }
|
#' }
|
||||||
|
#' See the documentation for \link{xgb.DMatrix} for more information about these fields.
|
||||||
#'
|
#'
|
||||||
#' \code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
|
#' Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
|
||||||
|
#' for a DMatrix that had 'qid' assigned.
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data(agaricus.train, package='xgboost')
|
#' data(agaricus.train, package='xgboost')
|
||||||
@ -236,24 +287,37 @@ getinfo <- function(object, ...) UseMethod("getinfo")
|
|||||||
#' @rdname getinfo
|
#' @rdname getinfo
|
||||||
#' @export
|
#' @export
|
||||||
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||||
|
allowed_int_fields <- 'group'
|
||||||
|
allowed_float_fields <- c(
|
||||||
|
'label', 'weight', 'base_margin',
|
||||||
|
'label_lower_bound', 'label_upper_bound'
|
||||||
|
)
|
||||||
|
allowed_str_fields <- c("feature_type", "feature_name")
|
||||||
|
allowed_fields <- c(allowed_float_fields, allowed_int_fields, allowed_str_fields, 'nrow')
|
||||||
|
|
||||||
if (typeof(name) != "character" ||
|
if (typeof(name) != "character" ||
|
||||||
length(name) != 1 ||
|
length(name) != 1 ||
|
||||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
!name %in% allowed_fields) {
|
||||||
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
|
stop("getinfo: name must be one of the following\n",
|
||||||
stop(
|
paste(paste0("'", allowed_fields, "'"), collapse = ", "))
|
||||||
"getinfo: name must be one of the following\n",
|
|
||||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
if (name == "feature_name" || name == "feature_type") {
|
if (name == "nrow") {
|
||||||
|
ret <- nrow(object)
|
||||||
|
} else if (name %in% allowed_str_fields) {
|
||||||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||||
} else if (name != "nrow") {
|
} else if (name %in% allowed_float_fields) {
|
||||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
ret <- .Call(XGDMatrixGetFloatInfo_R, object, name)
|
||||||
|
if (length(ret) > nrow(object)) {
|
||||||
|
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
||||||
|
}
|
||||||
|
} else if (name %in% allowed_int_fields) {
|
||||||
|
if (name == "group") {
|
||||||
|
name <- "group_ptr"
|
||||||
|
}
|
||||||
|
ret <- .Call(XGDMatrixGetUIntInfo_R, object, name)
|
||||||
if (length(ret) > nrow(object)) {
|
if (length(ret) > nrow(object)) {
|
||||||
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
ret <- nrow(object)
|
|
||||||
}
|
}
|
||||||
if (length(ret) == 0) return(NULL)
|
if (length(ret) == 0) return(NULL)
|
||||||
return(ret)
|
return(ret)
|
||||||
@ -270,13 +334,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
|
|||||||
#' @param ... other parameters
|
#' @param ... other parameters
|
||||||
#'
|
#'
|
||||||
#' @details
|
#' @details
|
||||||
#' The \code{name} field can be one of the following:
|
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
|
||||||
|
#' (which correspond to arguments in that function).
|
||||||
#'
|
#'
|
||||||
#' \itemize{
|
#' Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
|
||||||
#' \item \code{label}: label XGBoost learn from ;
|
#' but \bold{aren't} allowed here:\itemize{
|
||||||
#' \item \code{weight}: to do a weight rescale ;
|
#' \item data
|
||||||
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
#' \item missing
|
||||||
#' \item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
|
#' \item silent
|
||||||
|
#' \item nthread
|
||||||
#' }
|
#' }
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
@ -328,6 +394,12 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
|||||||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
|
if (name == "qid") {
|
||||||
|
if (NROW(info) != nrow(object))
|
||||||
|
stop("The length of qid assignments must equal to the number of rows in the input data")
|
||||||
|
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||||
|
return(TRUE)
|
||||||
|
}
|
||||||
if (name == "feature_weights") {
|
if (name == "feature_weights") {
|
||||||
if (length(info) != ncol(object)) {
|
if (length(info) != ncol(object)) {
|
||||||
stop("The number of feature weights must equal to the number of columns in the input data")
|
stop("The number of feature weights must equal to the number of columns in the input data")
|
||||||
|
|||||||
@ -23,14 +23,20 @@ Get information of an xgb.DMatrix object
|
|||||||
The \code{name} field can be one of the following:
|
The \code{name} field can be one of the following:
|
||||||
|
|
||||||
\itemize{
|
\itemize{
|
||||||
\item \code{label}: label XGBoost learn from ;
|
\item \code{label}
|
||||||
\item \code{weight}: to do a weight rescale ;
|
\item \code{weight}
|
||||||
\item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
\item \code{base_margin}
|
||||||
\item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
|
\item \code{label_lower_bound}
|
||||||
|
\item \code{label_upper_bound}
|
||||||
|
\item \code{group}
|
||||||
|
\item \code{feature_type}
|
||||||
|
\item \code{feature_name}
|
||||||
|
\item \code{nrow}
|
||||||
}
|
}
|
||||||
|
See the documentation for \link{xgb.DMatrix} for more information about these fields.
|
||||||
|
|
||||||
\code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
|
Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
|
||||||
|
for a DMatrix that had 'qid' assigned.
|
||||||
}
|
}
|
||||||
\examples{
|
\examples{
|
||||||
data(agaricus.train, package='xgboost')
|
data(agaricus.train, package='xgboost')
|
||||||
|
|||||||
@ -22,13 +22,15 @@ setinfo(object, ...)
|
|||||||
Set information of an xgb.DMatrix object
|
Set information of an xgb.DMatrix object
|
||||||
}
|
}
|
||||||
\details{
|
\details{
|
||||||
The \code{name} field can be one of the following:
|
See the documentation for \link{xgb.DMatrix} for possible fields that can be set
|
||||||
|
(which correspond to arguments in that function).
|
||||||
|
|
||||||
\itemize{
|
Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
|
||||||
\item \code{label}: label XGBoost learn from ;
|
but \bold{aren't} allowed here:\itemize{
|
||||||
\item \code{weight}: to do a weight rescale ;
|
\item data
|
||||||
\item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
\item missing
|
||||||
\item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
|
\item silent
|
||||||
|
\item nthread
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
\examples{
|
\examples{
|
||||||
|
|||||||
@ -6,11 +6,18 @@
|
|||||||
\usage{
|
\usage{
|
||||||
xgb.DMatrix(
|
xgb.DMatrix(
|
||||||
data,
|
data,
|
||||||
info = list(),
|
label = NULL,
|
||||||
|
weight = NULL,
|
||||||
|
base_margin = NULL,
|
||||||
missing = NA,
|
missing = NA,
|
||||||
silent = FALSE,
|
silent = FALSE,
|
||||||
|
feature_names = colnames(data),
|
||||||
nthread = NULL,
|
nthread = NULL,
|
||||||
...
|
group = NULL,
|
||||||
|
qid = NULL,
|
||||||
|
label_lower_bound = NULL,
|
||||||
|
label_upper_bound = NULL,
|
||||||
|
feature_weights = NULL
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
@ -19,17 +26,35 @@ a \code{dgRMatrix} object,
|
|||||||
a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
||||||
interpreted as a row vector), or a character string representing a filename.}
|
interpreted as a row vector), or a character string representing a filename.}
|
||||||
|
|
||||||
\item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object.
|
\item{label}{Label of the training data.}
|
||||||
See \code{\link{setinfo}} for the specific allowed kinds of}
|
|
||||||
|
\item{weight}{Weight for each instance.
|
||||||
|
|
||||||
|
Note that, for ranking task, weights are per-group. In ranking task, one weight
|
||||||
|
is assigned to each group (not each data point). This is because we
|
||||||
|
only care about the relative ordering of data points within each group,
|
||||||
|
so it doesn't make sense to assign weights to individual data points.}
|
||||||
|
|
||||||
|
\item{base_margin}{Base margin used for boosting from existing model.}
|
||||||
|
|
||||||
\item{missing}{a float value to represent missing values in data (used only when input is a dense matrix).
|
\item{missing}{a float value to represent missing values in data (used only when input is a dense matrix).
|
||||||
It is useful when a 0 or some other extreme value represents missing values in data.}
|
It is useful when a 0 or some other extreme value represents missing values in data.}
|
||||||
|
|
||||||
\item{silent}{whether to suppress printing an informational message after loading from a file.}
|
\item{silent}{whether to suppress printing an informational message after loading from a file.}
|
||||||
|
|
||||||
|
\item{feature_names}{Set names for features.}
|
||||||
|
|
||||||
\item{nthread}{Number of threads used for creating DMatrix.}
|
\item{nthread}{Number of threads used for creating DMatrix.}
|
||||||
|
|
||||||
\item{...}{the \code{info} data could be passed directly as parameters, without creating an \code{info} list.}
|
\item{group}{Group sizes for all ranking groups.}
|
||||||
|
|
||||||
|
\item{qid}{Query ID for data samples, used for ranking.}
|
||||||
|
|
||||||
|
\item{label_lower_bound}{Lower bound for survival training.}
|
||||||
|
|
||||||
|
\item{label_upper_bound}{Upper bound for survival training.}
|
||||||
|
|
||||||
|
\item{feature_weights}{Set feature weights for column sampling.}
|
||||||
}
|
}
|
||||||
\description{
|
\description{
|
||||||
Construct xgb.DMatrix object from either a dense matrix, a sparse matrix, or a local file.
|
Construct xgb.DMatrix object from either a dense matrix, a sparse matrix, or a local file.
|
||||||
|
|||||||
@ -39,7 +39,8 @@ extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
|||||||
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
extern SEXP XGDMatrixGetFloatInfo_R(SEXP, SEXP);
|
||||||
|
extern SEXP XGDMatrixGetUIntInfo_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||||
@ -76,7 +77,8 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
|
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
|
||||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
{"XGDMatrixGetFloatInfo_R", (DL_FUNC) &XGDMatrixGetFloatInfo_R, 2},
|
||||||
|
{"XGDMatrixGetUIntInfo_R", (DL_FUNC) &XGDMatrixGetUIntInfo_R, 2},
|
||||||
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
||||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
@ -412,17 +413,27 @@ XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
XGB_DLL SEXP XGDMatrixGetFloatInfo_R(SEXP handle, SEXP field) {
|
||||||
SEXP ret;
|
SEXP ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
bst_ulong olen;
|
bst_ulong olen;
|
||||||
const float *res;
|
const float *res;
|
||||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||||
ret = PROTECT(allocVector(REALSXP, olen));
|
ret = PROTECT(allocVector(REALSXP, olen));
|
||||||
double *ret_ = REAL(ret);
|
std::copy(res, res + olen, REAL(ret));
|
||||||
for (size_t i = 0; i < olen; ++i) {
|
R_API_END();
|
||||||
ret_[i] = res[i];
|
UNPROTECT(1);
|
||||||
}
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
XGB_DLL SEXP XGDMatrixGetUIntInfo_R(SEXP handle, SEXP field) {
|
||||||
|
SEXP ret;
|
||||||
|
R_API_BEGIN();
|
||||||
|
bst_ulong olen;
|
||||||
|
const unsigned *res;
|
||||||
|
CHECK_CALL(XGDMatrixGetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||||
|
ret = PROTECT(allocVector(INTSXP, olen));
|
||||||
|
std::copy(res, res + olen, INTEGER(ret));
|
||||||
R_API_END();
|
R_API_END();
|
||||||
UNPROTECT(1);
|
UNPROTECT(1);
|
||||||
return ret;
|
return ret;
|
||||||
|
|||||||
@ -106,12 +106,20 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
|
|||||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
|
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief get info vector from matrix
|
* \brief get info vector (float type) from matrix
|
||||||
* \param handle an instance of data matrix
|
* \param handle an instance of data matrix
|
||||||
* \param field field name
|
* \param field field name
|
||||||
* \return info vector
|
* \return info vector
|
||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field);
|
XGB_DLL SEXP XGDMatrixGetFloatInfo_R(SEXP handle, SEXP field);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief get info vector (uint type) from matrix
|
||||||
|
* \param handle an instance of data matrix
|
||||||
|
* \param field field name
|
||||||
|
* \return info vector
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGDMatrixGetUIntInfo_R(SEXP handle, SEXP field);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief return number of rows
|
* \brief return number of rows
|
||||||
|
|||||||
@ -56,7 +56,7 @@ test_that("parameter validation works", {
|
|||||||
y <- d[, "x1"] + d[, "x2"]^2 +
|
y <- d[, "x1"] + d[, "x2"]^2 +
|
||||||
ifelse(d[, "x3"] > .5, d[, "x3"]^2, 2^d[, "x3"]) +
|
ifelse(d[, "x3"] > .5, d[, "x3"]^2, 2^d[, "x3"]) +
|
||||||
rnorm(10)
|
rnorm(10)
|
||||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||||
|
|
||||||
correct <- function() {
|
correct <- function() {
|
||||||
params <- list(
|
params <- list(
|
||||||
@ -124,7 +124,7 @@ test_that("dart prediction works", {
|
|||||||
expect_false(all(matrix(pred_by_xgboost_0, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
|
expect_false(all(matrix(pred_by_xgboost_0, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
|
||||||
|
|
||||||
set.seed(1994)
|
set.seed(1994)
|
||||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||||
booster_by_train <- xgb.train(
|
booster_by_train <- xgb.train(
|
||||||
params = list(
|
params = list(
|
||||||
booster = "dart",
|
booster = "dart",
|
||||||
@ -186,7 +186,7 @@ test_that("train and predict softprob", {
|
|||||||
x3 = rnorm(100)
|
x3 = rnorm(100)
|
||||||
)
|
)
|
||||||
y <- sample.int(10, 100, replace = TRUE) - 1
|
y <- sample.int(10, 100, replace = TRUE) - 1
|
||||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||||
booster <- xgb.train(
|
booster <- xgb.train(
|
||||||
params = list(tree_method = "hist", nthread = n_threads),
|
params = list(tree_method = "hist", nthread = n_threads),
|
||||||
data = dtrain, nrounds = 4, num_class = 10,
|
data = dtrain, nrounds = 4, num_class = 10,
|
||||||
@ -643,3 +643,28 @@ test_that("Can use multi-output labels with custom objectives", {
|
|||||||
expect_true(cor(y, pred[, 1]) > 0.9)
|
expect_true(cor(y, pred[, 1]) > 0.9)
|
||||||
expect_true(cor(y, pred[, 2]) < -0.9)
|
expect_true(cor(y, pred[, 2]) < -0.9)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Can use ranking objectives with either 'qid' or 'group'", {
|
||||||
|
set.seed(123)
|
||||||
|
x <- matrix(rnorm(100 * 10), nrow = 100)
|
||||||
|
y <- sample(2, size = 100, replace = TRUE) - 1
|
||||||
|
qid <- c(rep(1, 20), rep(2, 20), rep(3, 60))
|
||||||
|
gr <- c(20, 20, 60)
|
||||||
|
|
||||||
|
dmat_qid <- xgb.DMatrix(x, label = y, qid = qid)
|
||||||
|
dmat_gr <- xgb.DMatrix(x, label = y, group = gr)
|
||||||
|
|
||||||
|
params <- list(tree_method = "hist",
|
||||||
|
lambdarank_num_pair_per_sample = 8,
|
||||||
|
objective = "rank:ndcg",
|
||||||
|
lambdarank_pair_method = "topk",
|
||||||
|
nthread = n_threads)
|
||||||
|
set.seed(123)
|
||||||
|
model_qid <- xgb.train(params, dmat_qid, nrounds = 5)
|
||||||
|
set.seed(123)
|
||||||
|
model_gr <- xgb.train(params, dmat_gr, nrounds = 5)
|
||||||
|
|
||||||
|
pred_qid <- predict(model_qid, x)
|
||||||
|
pred_gr <- predict(model_gr, x)
|
||||||
|
expect_equal(pred_qid, pred_gr)
|
||||||
|
})
|
||||||
|
|||||||
@ -305,3 +305,20 @@ test_that("xgb.DMatrix: error on three-dimensional array", {
|
|||||||
dim(y) <- c(50, 4, 2)
|
dim(y) <- c(50, 4, 2)
|
||||||
expect_error(xgb.DMatrix(data = x, label = y))
|
expect_error(xgb.DMatrix(data = x, label = y))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.DMatrix: can get group for both 'qid' and 'group' constructors", {
|
||||||
|
set.seed(123)
|
||||||
|
x <- matrix(rnorm(1000), nrow = 100)
|
||||||
|
group <- c(20, 20, 60)
|
||||||
|
qid <- c(rep(1, 20), rep(2, 20), rep(3, 60))
|
||||||
|
|
||||||
|
gr_mat <- xgb.DMatrix(x, group = group)
|
||||||
|
qid_mat <- xgb.DMatrix(x, qid = qid)
|
||||||
|
|
||||||
|
info_gr <- getinfo(gr_mat, "group")
|
||||||
|
info_qid <- getinfo(qid_mat, "group")
|
||||||
|
expect_equal(info_gr, info_qid)
|
||||||
|
|
||||||
|
expected_gr <- c(0, 20, 40, 100)
|
||||||
|
expect_equal(info_gr, expected_gr)
|
||||||
|
})
|
||||||
|
|||||||
@ -47,7 +47,7 @@ test_that("interaction constraints scientific representation", {
|
|||||||
d <- matrix(rexp(rows, rate = .1), nrow = rows, ncol = cols)
|
d <- matrix(rexp(rows, rate = .1), nrow = rows, ncol = cols)
|
||||||
y <- rnorm(rows)
|
y <- rnorm(rows)
|
||||||
|
|
||||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||||
inc <- list(c(seq.int(from = 0, to = cols, by = 1)))
|
inc <- list(c(seq.int(from = 0, to = cols, by = 1)))
|
||||||
|
|
||||||
with_inc <- xgb.train(
|
with_inc <- xgb.train(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user