[R-package] JSON dump format and a couple of bugfixes (#1855)

* [R-package] JSON tree dump interface

* [R-package] precision bugfix in xgb.attributes

* [R-package] bugfix for cb.early.stop called from xgb.cv

* [R-package] a bit more clarity on labels checking in xgb.cv

* [R-package] test JSON dump for gblinear as well

* whitespace lint
This commit is contained in:
Vadim Khotilovich
2016-12-11 12:48:39 -06:00
committed by Tianqi Chen
parent 0268dedeea
commit b21e658a02
10 changed files with 72 additions and 22 deletions

View File

@@ -229,7 +229,7 @@ cb.reset.parameters <- function(new_params) {
xgb.parameters(env$bst$handle) <- pars
} else {
for (fd in env$bst_folds)
xgb.parameters(fd$bst$handle) <- pars
xgb.parameters(fd$bst) <- pars
}
}
attr(callback, 'is_pre_iteration') <- TRUE

View File

@@ -339,7 +339,7 @@ xgb.attributes <- function(object) {
# Q: should we warn a user about non-scalar elements?
a <- lapply(a, function(x) {
if (is.null(x)) return(NULL)
if (is.numeric(value[1])) {
if (is.numeric(x[1])) {
format(x[1], digits = 17)
} else {
as.character(x[1])

View File

@@ -16,10 +16,10 @@
#'
#' See \code{\link{xgb.train}} for further details.
#' See also demo/ for walkthrough example in R.
#' @param data takes an \code{xgb.DMatrix} or \code{Matrix} as the input.
#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
#' @param nrounds the max number of iterations
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is \code{DMatrix}.
#' @param label vector of response values. Should be provided only when data is an R-matrix.
#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
#' that NA values should be considered as 'missing' by the algorithm.
#' Sometimes, 0 or other extreme value might be used to represent missing values.
@@ -129,10 +129,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
#if (is.null(params[['eval_metric']]) && is.null(feval))
# stop("Either 'eval_metric' or 'feval' must be provided for CV")
# Labels
if (class(data) == 'xgb.DMatrix')
labels <- getinfo(data, 'label')
if (is.null(labels))
# Check the labels
if ( (class(data) == 'xgb.DMatrix' && is.null(getinfo(data, 'label'))) ||
(class(data) != 'xgb.DMatrix' && is.null(label)))
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
# CV folds

View File

@@ -14,6 +14,7 @@
#' When this option is on, the model dump comes with two additional statistics:
#' gain is the approximate loss function gain we get in each split;
#' cover is the sum of second order gradient in each node.
#' @param dump_fomat either 'text' or 'json' format could be specified.
#' @param ... currently not used
#'
#' @return
@@ -30,10 +31,15 @@
#' xgb.dump(bst, 'xgb.model.dump', with_stats = TRUE)
#'
#' # print the model without saving it to a file
#' print(xgb.dump(bst))
#' print(xgb.dump(bst, with_stats = TRUE))
#'
#' # print in JSON format:
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
#'
#' @export
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ...) {
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, dump_format = c("text", "json"), ...) {
check.deprecation(...)
dump_format <- match.arg(dump_format)
if (class(model) != "xgb.Booster")
stop("model: argument must be of type xgb.Booster")
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1))
@@ -42,12 +48,15 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ..
stop("fmap: argument must be of type character (when provided)")
model <- xgb.Booster.check(model)
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost")
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats),
as.character(dump_format), PACKAGE = "xgboost")
if (is.null(fname))
model_dump <- stri_replace_all_regex(model_dump, '\t', '')
model_dump <- unlist(stri_split_regex(model_dump, '\n'))
if (dump_format == "text")
model_dump <- unlist(stri_split_regex(model_dump, '\n'))
model_dump <- grep('^\\s*$', model_dump, invert = TRUE, value = TRUE)
if (is.null(fname)) {