[R] Move all DMatrix fields to function arguments (#9862)
This commit is contained in:
parent
1094d6015d
commit
562352101d
@ -8,13 +8,24 @@
|
||||
#' a \code{dgRMatrix} object,
|
||||
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
||||
#' interpreted as a row vector), or a character string representing a filename.
|
||||
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
|
||||
#' See \code{\link{setinfo}} for the specific allowed kinds of
|
||||
#' @param label Label of the training data.
|
||||
#' @param weight Weight for each instance.
|
||||
#'
|
||||
#' Note that, for ranking tasks, weights are per-group. In a ranking task, one weight
|
||||
#' is assigned to each group (not each data point). This is because we
|
||||
#' only care about the relative ordering of data points within each group,
|
||||
#' so it doesn't make sense to assign weights to individual data points.
|
||||
#' @param base_margin Base margin used for boosting from existing model.
|
||||
#' @param missing a float value that represents missing values in data (used only when input is a dense matrix).
|
||||
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
||||
#' @param silent whether to suppress printing an informational message after loading from a file.
|
||||
#' @param feature_names Set names for features.
|
||||
#' @param nthread Number of threads used for creating DMatrix.
|
||||
#' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list.
|
||||
#' @param group Group sizes for all ranking groups.
|
||||
#' @param qid Query ID for data samples, used for ranking.
|
||||
#' @param label_lower_bound Lower bound for survival training.
|
||||
#' @param label_upper_bound Upper bound for survival training.
|
||||
#' @param feature_weights Set feature weights for column sampling.
|
||||
#'
|
||||
#' @details
|
||||
#' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}.
|
||||
@ -34,8 +45,24 @@
|
||||
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
||||
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
|
||||
#' @export
|
||||
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
|
||||
cnames <- NULL
|
||||
xgb.DMatrix <- function(
|
||||
data,
|
||||
label = NULL,
|
||||
weight = NULL,
|
||||
base_margin = NULL,
|
||||
missing = NA,
|
||||
silent = FALSE,
|
||||
feature_names = colnames(data),
|
||||
nthread = NULL,
|
||||
group = NULL,
|
||||
qid = NULL,
|
||||
label_lower_bound = NULL,
|
||||
label_upper_bound = NULL,
|
||||
feature_weights = NULL
|
||||
) {
|
||||
if (!is.null(group) && !is.null(qid)) {
|
||||
stop("Either one of 'group' or 'qid' should be NULL")
|
||||
}
|
||||
if (typeof(data) == "character") {
|
||||
if (length(data) > 1)
|
||||
stop("'data' has class 'character' and length ", length(data),
|
||||
@ -44,7 +71,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
||||
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
|
||||
} else if (is.matrix(data)) {
|
||||
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
|
||||
cnames <- colnames(data)
|
||||
} else if (inherits(data, "dgCMatrix")) {
|
||||
handle <- .Call(
|
||||
XGDMatrixCreateFromCSC_R,
|
||||
@ -55,7 +81,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
||||
missing,
|
||||
as.integer(NVL(nthread, -1))
|
||||
)
|
||||
cnames <- colnames(data)
|
||||
} else if (inherits(data, "dgRMatrix")) {
|
||||
handle <- .Call(
|
||||
XGDMatrixCreateFromCSR_R,
|
||||
@ -66,7 +91,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
||||
missing,
|
||||
as.integer(NVL(nthread, -1))
|
||||
)
|
||||
cnames <- colnames(data)
|
||||
} else if (inherits(data, "dsparseVector")) {
|
||||
indptr <- c(0L, as.integer(length(data@i)))
|
||||
ind <- as.integer(data@i) - 1L
|
||||
@ -82,17 +106,38 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
||||
} else {
|
||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||
}
|
||||
|
||||
dmat <- handle
|
||||
attributes(dmat) <- list(class = "xgb.DMatrix")
|
||||
if (!is.null(cnames)) {
|
||||
setinfo(dmat, "feature_name", cnames)
|
||||
|
||||
if (!is.null(label)) {
|
||||
setinfo(dmat, "label", label)
|
||||
}
|
||||
if (!is.null(weight)) {
|
||||
setinfo(dmat, "weight", weight)
|
||||
}
|
||||
if (!is.null(base_margin)) {
|
||||
setinfo(dmat, "base_margin", base_margin)
|
||||
}
|
||||
if (!is.null(feature_names)) {
|
||||
setinfo(dmat, "feature_name", feature_names)
|
||||
}
|
||||
if (!is.null(group)) {
|
||||
setinfo(dmat, "group", group)
|
||||
}
|
||||
if (!is.null(qid)) {
|
||||
setinfo(dmat, "qid", qid)
|
||||
}
|
||||
if (!is.null(label_lower_bound)) {
|
||||
setinfo(dmat, "label_lower_bound", label_lower_bound)
|
||||
}
|
||||
if (!is.null(label_upper_bound)) {
|
||||
setinfo(dmat, "label_upper_bound", label_upper_bound)
|
||||
}
|
||||
if (!is.null(feature_weights)) {
|
||||
setinfo(dmat, "feature_weights", feature_weights)
|
||||
}
|
||||
|
||||
info <- append(info, list(...))
|
||||
for (i in seq_along(info)) {
|
||||
p <- info[i]
|
||||
setinfo(dmat, names(p), p[[1]])
|
||||
}
|
||||
return(dmat)
|
||||
}
|
||||
|
||||
@ -211,14 +256,20 @@ dimnames.xgb.DMatrix <- function(x) {
|
||||
#' The \code{name} field can be one of the following:
|
||||
#'
|
||||
#' \itemize{
|
||||
#' \item \code{label}: label XGBoost learn from ;
|
||||
#' \item \code{weight}: to do a weight rescale ;
|
||||
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
||||
#' \item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
|
||||
#'
|
||||
#' \item \code{label}
|
||||
#' \item \code{weight}
|
||||
#' \item \code{base_margin}
|
||||
#' \item \code{label_lower_bound}
|
||||
#' \item \code{label_upper_bound}
|
||||
#' \item \code{group}
|
||||
#' \item \code{feature_type}
|
||||
#' \item \code{feature_name}
|
||||
#' \item \code{nrow}
|
||||
#' }
|
||||
#' See the documentation for \link{xgb.DMatrix} for more information about these fields.
|
||||
#'
|
||||
#' \code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
|
||||
#' Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
|
||||
#' for a DMatrix that had 'qid' assigned.
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
@ -236,24 +287,37 @@ getinfo <- function(object, ...) UseMethod("getinfo")
|
||||
#' @rdname getinfo
|
||||
#' @export
|
||||
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
allowed_int_fields <- 'group'
|
||||
allowed_float_fields <- c(
|
||||
'label', 'weight', 'base_margin',
|
||||
'label_lower_bound', 'label_upper_bound'
|
||||
)
|
||||
allowed_str_fields <- c("feature_type", "feature_name")
|
||||
allowed_fields <- c(allowed_float_fields, allowed_int_fields, allowed_str_fields, 'nrow')
|
||||
|
||||
if (typeof(name) != "character" ||
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
|
||||
stop(
|
||||
"getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
|
||||
)
|
||||
!name %in% allowed_fields) {
|
||||
stop("getinfo: name must be one of the following\n",
|
||||
paste(paste0("'", allowed_fields, "'"), collapse = ", "))
|
||||
}
|
||||
if (name == "feature_name" || name == "feature_type") {
|
||||
if (name == "nrow") {
|
||||
ret <- nrow(object)
|
||||
} else if (name %in% allowed_str_fields) {
|
||||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||
} else if (name != "nrow") {
|
||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||
} else if (name %in% allowed_float_fields) {
|
||||
ret <- .Call(XGDMatrixGetFloatInfo_R, object, name)
|
||||
if (length(ret) > nrow(object)) {
|
||||
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
||||
}
|
||||
} else if (name %in% allowed_int_fields) {
|
||||
if (name == "group") {
|
||||
name <- "group_ptr"
|
||||
}
|
||||
ret <- .Call(XGDMatrixGetUIntInfo_R, object, name)
|
||||
if (length(ret) > nrow(object)) {
|
||||
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
||||
}
|
||||
} else {
|
||||
ret <- nrow(object)
|
||||
}
|
||||
if (length(ret) == 0) return(NULL)
|
||||
return(ret)
|
||||
@ -270,13 +334,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
#' @param ... other parameters
|
||||
#'
|
||||
#' @details
|
||||
#' The \code{name} field can be one of the following:
|
||||
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
|
||||
#' (which correspond to arguments in that function).
|
||||
#'
|
||||
#' \itemize{
|
||||
#' \item \code{label}: label XGBoost learn from ;
|
||||
#' \item \code{weight}: to do a weight rescale ;
|
||||
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
||||
#' \item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
|
||||
#' Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
|
||||
#' but \bold{aren't} allowed here:\itemize{
|
||||
#' \item data
|
||||
#' \item missing
|
||||
#' \item silent
|
||||
#' \item nthread
|
||||
#' }
|
||||
#'
|
||||
#' @examples
|
||||
@ -328,6 +394,12 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "qid") {
|
||||
if (NROW(info) != nrow(object))
|
||||
stop("The length of qid assignments must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "feature_weights") {
|
||||
if (length(info) != ncol(object)) {
|
||||
stop("The number of feature weights must equal to the number of columns in the input data")
|
||||
|
||||
@ -23,14 +23,20 @@ Get information of an xgb.DMatrix object
|
||||
The \code{name} field can be one of the following:
|
||||
|
||||
\itemize{
|
||||
\item \code{label}: label XGBoost learn from ;
|
||||
\item \code{weight}: to do a weight rescale ;
|
||||
\item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
||||
\item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
|
||||
|
||||
\item \code{label}
|
||||
\item \code{weight}
|
||||
\item \code{base_margin}
|
||||
\item \code{label_lower_bound}
|
||||
\item \code{label_upper_bound}
|
||||
\item \code{group}
|
||||
\item \code{feature_type}
|
||||
\item \code{feature_name}
|
||||
\item \code{nrow}
|
||||
}
|
||||
See the documentation for \link{xgb.DMatrix} for more information about these fields.
|
||||
|
||||
\code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
|
||||
Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
|
||||
for a DMatrix that had 'qid' assigned.
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='xgboost')
|
||||
|
||||
@ -22,13 +22,15 @@ setinfo(object, ...)
|
||||
Set information of an xgb.DMatrix object
|
||||
}
|
||||
\details{
|
||||
The \code{name} field can be one of the following:
|
||||
See the documentation for \link{xgb.DMatrix} for possible fields that can be set
|
||||
(which correspond to arguments in that function).
|
||||
|
||||
\itemize{
|
||||
\item \code{label}: label XGBoost learn from ;
|
||||
\item \code{weight}: to do a weight rescale ;
|
||||
\item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
|
||||
\item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
|
||||
Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
|
||||
but \bold{aren't} allowed here:\itemize{
|
||||
\item data
|
||||
\item missing
|
||||
\item silent
|
||||
\item nthread
|
||||
}
|
||||
}
|
||||
\examples{
|
||||
|
||||
@ -6,11 +6,18 @@
|
||||
\usage{
|
||||
xgb.DMatrix(
|
||||
data,
|
||||
info = list(),
|
||||
label = NULL,
|
||||
weight = NULL,
|
||||
base_margin = NULL,
|
||||
missing = NA,
|
||||
silent = FALSE,
|
||||
feature_names = colnames(data),
|
||||
nthread = NULL,
|
||||
...
|
||||
group = NULL,
|
||||
qid = NULL,
|
||||
label_lower_bound = NULL,
|
||||
label_upper_bound = NULL,
|
||||
feature_weights = NULL
|
||||
)
|
||||
}
|
||||
\arguments{
|
||||
@ -19,17 +26,35 @@ a \code{dgRMatrix} object,
|
||||
a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
||||
interpreted as a row vector), or a character string representing a filename.}
|
||||
|
||||
\item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object.
|
||||
See \code{\link{setinfo}} for the specific allowed kinds of}
|
||||
\item{label}{Label of the training data.}
|
||||
|
||||
\item{weight}{Weight for each instance.
|
||||
|
||||
Note that, for ranking tasks, weights are per-group. In a ranking task, one weight
|
||||
is assigned to each group (not each data point). This is because we
|
||||
only care about the relative ordering of data points within each group,
|
||||
so it doesn't make sense to assign weights to individual data points.}
|
||||
|
||||
\item{base_margin}{Base margin used for boosting from existing model.}
|
||||
|
||||
\item{missing}{a float value that represents missing values in data (used only when input is a dense matrix).
|
||||
It is useful when a 0 or some other extreme value represents missing values in data.}
|
||||
|
||||
\item{silent}{whether to suppress printing an informational message after loading from a file.}
|
||||
|
||||
\item{feature_names}{Set names for features.}
|
||||
|
||||
\item{nthread}{Number of threads used for creating DMatrix.}
|
||||
|
||||
\item{...}{the \code{info} data could be passed directly as parameters, without creating an \code{info} list.}
|
||||
\item{group}{Group sizes for all ranking groups.}
|
||||
|
||||
\item{qid}{Query ID for data samples, used for ranking.}
|
||||
|
||||
\item{label_lower_bound}{Lower bound for survival training.}
|
||||
|
||||
\item{label_upper_bound}{Upper bound for survival training.}
|
||||
|
||||
\item{feature_weights}{Set feature weights for column sampling.}
|
||||
}
|
||||
\description{
|
||||
Construct xgb.DMatrix object from either a dense matrix, a sparse matrix, or a local file.
|
||||
|
||||
@ -39,7 +39,8 @@ extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetFloatInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetUIntInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||
@ -76,7 +77,8 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
|
||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
||||
{"XGDMatrixGetFloatInfo_R", (DL_FUNC) &XGDMatrixGetFloatInfo_R, 2},
|
||||
{"XGDMatrixGetUIntInfo_R", (DL_FUNC) &XGDMatrixGetUIntInfo_R, 2},
|
||||
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@ -412,17 +413,27 @@ XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
XGB_DLL SEXP XGDMatrixGetFloatInfo_R(SEXP handle, SEXP field) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
const float *res;
|
||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||
ret = PROTECT(allocVector(REALSXP, olen));
|
||||
double *ret_ = REAL(ret);
|
||||
for (size_t i = 0; i < olen; ++i) {
|
||||
ret_[i] = res[i];
|
||||
}
|
||||
std::copy(res, res + olen, REAL(ret));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixGetUIntInfo_R(SEXP handle, SEXP field) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
const unsigned *res;
|
||||
CHECK_CALL(XGDMatrixGetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||
ret = PROTECT(allocVector(INTSXP, olen));
|
||||
std::copy(res, res + olen, INTEGER(ret));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
|
||||
@ -106,12 +106,20 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
|
||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
|
||||
|
||||
/*!
|
||||
* \brief get info vector from matrix
|
||||
* \brief get info vector (float type) from matrix
|
||||
* \param handle an instance of data matrix
|
||||
* \param field field name
|
||||
* \return info vector
|
||||
*/
|
||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field);
|
||||
XGB_DLL SEXP XGDMatrixGetFloatInfo_R(SEXP handle, SEXP field);
|
||||
|
||||
/*!
|
||||
* \brief get info vector (uint type) from matrix
|
||||
* \param handle an instance of data matrix
|
||||
* \param field field name
|
||||
* \return info vector
|
||||
*/
|
||||
XGB_DLL SEXP XGDMatrixGetUIntInfo_R(SEXP handle, SEXP field);
|
||||
|
||||
/*!
|
||||
* \brief return number of rows
|
||||
|
||||
@ -56,7 +56,7 @@ test_that("parameter validation works", {
|
||||
y <- d[, "x1"] + d[, "x2"]^2 +
|
||||
ifelse(d[, "x3"] > .5, d[, "x3"]^2, 2^d[, "x3"]) +
|
||||
rnorm(10)
|
||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
||||
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||
|
||||
correct <- function() {
|
||||
params <- list(
|
||||
@ -124,7 +124,7 @@ test_that("dart prediction works", {
|
||||
expect_false(all(matrix(pred_by_xgboost_0, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
|
||||
|
||||
set.seed(1994)
|
||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
||||
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||
booster_by_train <- xgb.train(
|
||||
params = list(
|
||||
booster = "dart",
|
||||
@ -186,7 +186,7 @@ test_that("train and predict softprob", {
|
||||
x3 = rnorm(100)
|
||||
)
|
||||
y <- sample.int(10, 100, replace = TRUE) - 1
|
||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
||||
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||
booster <- xgb.train(
|
||||
params = list(tree_method = "hist", nthread = n_threads),
|
||||
data = dtrain, nrounds = 4, num_class = 10,
|
||||
@ -643,3 +643,28 @@ test_that("Can use multi-output labels with custom objectives", {
|
||||
expect_true(cor(y, pred[, 1]) > 0.9)
|
||||
expect_true(cor(y, pred[, 2]) < -0.9)
|
||||
})
|
||||
|
||||
test_that("Can use ranking objectives with either 'qid' or 'group'", {
|
||||
set.seed(123)
|
||||
x <- matrix(rnorm(100 * 10), nrow = 100)
|
||||
y <- sample(2, size = 100, replace = TRUE) - 1
|
||||
qid <- c(rep(1, 20), rep(2, 20), rep(3, 60))
|
||||
gr <- c(20, 20, 60)
|
||||
|
||||
dmat_qid <- xgb.DMatrix(x, label = y, qid = qid)
|
||||
dmat_gr <- xgb.DMatrix(x, label = y, group = gr)
|
||||
|
||||
params <- list(tree_method = "hist",
|
||||
lambdarank_num_pair_per_sample = 8,
|
||||
objective = "rank:ndcg",
|
||||
lambdarank_pair_method = "topk",
|
||||
nthread = n_threads)
|
||||
set.seed(123)
|
||||
model_qid <- xgb.train(params, dmat_qid, nrounds = 5)
|
||||
set.seed(123)
|
||||
model_gr <- xgb.train(params, dmat_gr, nrounds = 5)
|
||||
|
||||
pred_qid <- predict(model_qid, x)
|
||||
pred_gr <- predict(model_gr, x)
|
||||
expect_equal(pred_qid, pred_gr)
|
||||
})
|
||||
|
||||
@ -305,3 +305,20 @@ test_that("xgb.DMatrix: error on three-dimensional array", {
|
||||
dim(y) <- c(50, 4, 2)
|
||||
expect_error(xgb.DMatrix(data = x, label = y))
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: can get group for both 'qid' and 'group' constructors", {
|
||||
set.seed(123)
|
||||
x <- matrix(rnorm(1000), nrow = 100)
|
||||
group <- c(20, 20, 60)
|
||||
qid <- c(rep(1, 20), rep(2, 20), rep(3, 60))
|
||||
|
||||
gr_mat <- xgb.DMatrix(x, group = group)
|
||||
qid_mat <- xgb.DMatrix(x, qid = qid)
|
||||
|
||||
info_gr <- getinfo(gr_mat, "group")
|
||||
info_qid <- getinfo(qid_mat, "group")
|
||||
expect_equal(info_gr, info_qid)
|
||||
|
||||
expected_gr <- c(0, 20, 40, 100)
|
||||
expect_equal(info_gr, expected_gr)
|
||||
})
|
||||
|
||||
@ -47,7 +47,7 @@ test_that("interaction constraints scientific representation", {
|
||||
d <- matrix(rexp(rows, rate = .1), nrow = rows, ncol = cols)
|
||||
y <- rnorm(rows)
|
||||
|
||||
dtrain <- xgb.DMatrix(data = d, info = list(label = y), nthread = n_threads)
|
||||
dtrain <- xgb.DMatrix(data = d, label = y, nthread = n_threads)
|
||||
inc <- list(c(seq.int(from = 0, to = cols, by = 1)))
|
||||
|
||||
with_inc <- xgb.train(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user