From 56bd442b316ca87e50854c98788b7675814a8d7d Mon Sep 17 00:00:00 2001 From: Vadim Khotilovich Date: Mon, 27 Jun 2016 01:57:57 -0500 Subject: [PATCH] [R] simplified the code; parameter style consistency --- R-package/R/xgb.DMatrix.save.R | 23 ++++++------- R-package/R/xgb.dump.R | 63 ++++++++++++++-------------------- R-package/R/xgb.save.R | 24 ++++++------- R-package/R/xgb.save.raw.R | 16 +++------ 4 files changed, 51 insertions(+), 75 deletions(-) diff --git a/R-package/R/xgb.DMatrix.save.R b/R-package/R/xgb.DMatrix.save.R index 63a0be691..9ceec801a 100644 --- a/R-package/R/xgb.DMatrix.save.R +++ b/R-package/R/xgb.DMatrix.save.R @@ -2,8 +2,8 @@ #' #' Save xgb.DMatrix object to binary file #' -#' @param DMatrix the DMatrix object -#' @param fname the name of the binary file. +#' @param dmatrix the \code{xgb.DMatrix} object +#' @param fname the name of the file to write. #' #' @examples #' data(agaricus.train, package='xgboost') @@ -12,15 +12,12 @@ #' xgb.DMatrix.save(dtrain, 'xgb.DMatrix.data') #' dtrain <- xgb.DMatrix('xgb.DMatrix.data') #' @export -xgb.DMatrix.save <- function(DMatrix, fname) { - if (typeof(fname) != "character") { - stop("xgb.save: fname must be character") - } - if (class(DMatrix) == "xgb.DMatrix") { - .Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE), - PACKAGE = "xgboost") - return(TRUE) - } - stop("xgb.DMatrix.save: the input must be xgb.DMatrix") - return(FALSE) +xgb.DMatrix.save <- function(dmatrix, fname) { + if (typeof(fname) != "character") + stop("fname must be character") + if (class(dmatrix) != "xgb.DMatrix") + stop("the input data must be xgb.DMatrix") + + .Call("XGDMatrixSaveBinary_R", dmatrix, fname, 0L, PACKAGE = "xgboost") + return(TRUE) } diff --git a/R-package/R/xgb.dump.R b/R-package/R/xgb.dump.R index b39359abd..ce8c8696e 100644 --- a/R-package/R/xgb.dump.R +++ b/R-package/R/xgb.dump.R @@ -2,11 +2,6 @@ #' #' Save a xgboost model to text file. Could be parsed later. #' -#' @importFrom magrittr %>% -#' @importFrom stringr str_replace -#' @importFrom data.table fread -#' @importFrom data.table := -#' @importFrom data.table setnames #' @param model the model object. #' @param fname the name of the text file where to save the model text dump. If not provided or set to \code{NULL} the function will return the model as a \code{character} vector. #' @param fmap feature map file representing the type of feature. @@ -15,10 +10,11 @@ #' See demo/ for walkthrough example in R, and #' \url{https://github.com/dmlc/xgboost/blob/master/demo/data/featmap.txt} #' for example Format. -#' @param with.stats whether dump statistics of splits +#' @param with_stats whether dump statistics of splits #' When this option is on, the model dump comes with two additional statistics: #' gain is the approximate loss function gain we get in each split; #' cover is the sum of second order gradient in each node. +#' @param ... currently not used #' #' @return #' if fname is not provided or set to \code{NULL} the function will return the model as a \code{character} vector. Otherwise it will return \code{TRUE}. @@ -28,43 +24,36 @@ #' data(agaricus.test, package='xgboost') #' train <- agaricus.train #' test <- agaricus.test -#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2, -#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") +#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, +#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") #' # save the model in file 'xgb.model.dump' -#' xgb.dump(bst, 'xgb.model.dump', with.stats = TRUE) +#' xgb.dump(bst, 'xgb.model.dump', with_stats = TRUE) #' #' # print the model without saving it to a file #' print(xgb.dump(bst)) #' @export -xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) { - if (class(model) != "xgb.Booster") { - stop("model: argument must be type xgb.Booster") +xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ...) { + check.deprecation(...) + if (class(model) != "xgb.Booster") + stop("model: argument must be of type xgb.Booster") + if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) + stop("fname: argument must be of type character (when provided)") + if (!(class(fmap) %in% c("character", "NULL") && length(fmap) <= 1)) + stop("fmap: argument must be of type character (when provided)") + + model <- xgb.Booster.check(model) + model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost") + + if (is.null(fname)) + model_dump <- str_replace_all(model_dump, '\t', '') + + model_dump <- unlist(str_split(model_dump, '\n')) + model_dump <- grep('(^$|^0$)', model_dump, invert = TRUE, value = TRUE) + + if (is.null(fname)) { + return(model_dump) } else { - model <- xgb.Booster.check(model) - } - if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) { - stop("fname: argument must be type character (when provided)") - } - if (!(class(fmap) %in% c("character", "NULL") && length(fname) <= 1)) { - stop("fmap: argument must be type character (when provided)") - } - - longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost") - - dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F) - - setnames(dt, "Lines") - - if(is.null(fname)) { - result <- dt[Lines != "0"][, Lines := str_replace(Lines, "^\t+", "")][Lines != ""][, paste(Lines)] - return(result) - } else { - result <- dt[Lines != "0"][Lines != ""][, paste(Lines)] %>% writeLines(fname) + writeLines(model_dump, fname) return(TRUE) } } - -# Avoid error messages during CRAN check. -# The reason is that these variables are never declared -# They are mainly column names inferred by Data.table... -globalVariables(c("Lines", ".")) diff --git a/R-package/R/xgb.save.R b/R-package/R/xgb.save.R index 7d595ddc6..5b2421b7f 100644 --- a/R-package/R/xgb.save.R +++ b/R-package/R/xgb.save.R @@ -3,29 +3,25 @@ #' Save xgboost model from xgboost or xgb.train #' #' @param model the model object. -#' @param fname the name of the binary file. +#' @param fname the name of the file to write. #' #' @examples #' data(agaricus.train, package='xgboost') #' data(agaricus.test, package='xgboost') #' train <- agaricus.train #' test <- agaricus.test -#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2, -#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") +#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, +#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") #' xgb.save(bst, 'xgb.model') #' bst <- xgb.load('xgb.model') #' pred <- predict(bst, test$data) #' @export xgb.save <- function(model, fname) { - if (typeof(fname) != "character") { - stop("xgb.save: fname must be character") - } - if (class(model) == "xgb.Booster") { - model <- xgb.Booster.check(model) - .Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost") - return(TRUE) - } - stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save - xgb.DMatrix object.") - return(FALSE) + if (typeof(fname) != "character") + stop("fname must be character") + if (class(model) != "xgb.Booster") + stop("the input must be xgb.Booster. Use xgb.DMatrix.save to save xgb.DMatrix object.") + + .Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost") + return(TRUE) } diff --git a/R-package/R/xgb.save.raw.R b/R-package/R/xgb.save.raw.R index e61303add..1743b67d7 100644 --- a/R-package/R/xgb.save.raw.R +++ b/R-package/R/xgb.save.raw.R @@ -10,20 +10,14 @@ #' data(agaricus.test, package='xgboost') #' train <- agaricus.train #' test <- agaricus.test -#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2, -#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic") +#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, +#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") #' raw <- xgb.save.raw(bst) #' bst <- xgb.load(raw) #' pred <- predict(bst, test$data) +#' #' @export xgb.save.raw <- function(model) { - if (class(model) == "xgb.Booster"){ - model <- model$handle - } - if (class(model) == "xgb.Booster.handle") { - raw <- .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost") - return(raw) - } - stop("xgb.raw: the input must be xgb.Booster.handle. Use xgb.DMatrix.save to save - xgb.DMatrix object.") + model <- xgb.get.handle(model) + .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost") }