refactor dump function to adapt to the new possibilities of exporting a String

This commit is contained in:
El Potaeto
2015-01-09 00:14:01 +01:00
parent 6fd8bbe71a
commit 3e1eea0eea
4 changed files with 21 additions and 9 deletions

View File

@@ -2,8 +2,11 @@
#'
#' Save a xgboost model to text file. Could be parsed later.
#'
#' @importFrom magrittr %>%
#' @importFrom stringr str_split
#' @importFrom stringr str_replace_all
#' @param model the model object.
#' @param fname the name of the binary file.
#' @param fname the name of the text file where to save the model. If not provided or set to \code{NULL} the function will return the model as a \code{character} vector.
#' @param fmap feature map file representing the type of feature.
#' Detailed description could be found at
#' \url{https://github.com/tqchen/xgboost/wiki/Binary-Classification#dump-model}.
@@ -15,6 +18,9 @@
#' gain is the approximate loss function gain we get in each split;
#' cover is the sum of second order gradient in each node.
#'
#' @return
#' if fname is not provided or set to \code{NULL} the function will return the model as a \code{character} vector. Otherwise it will return \code{TRUE}.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
@@ -25,15 +31,17 @@
#' xgb.dump(bst, 'xgb.model.dump')
#' @export
#'
xgb.dump <- function(model, fname, fmap = "", with.stats=FALSE) {
xgb.dump <- function(model, fname = NULL, fmap = "", with.stats=FALSE) {
if (class(model) != "xgb.Booster") {
stop("xgb.dump: first argument must be type xgb.Booster")
}
if (typeof(fname) != "character") {
stop("xgb.dump: second argument must be type character")
if (!class(fname) %in% c("character", "NULL")) {
stop("xgb.dump: second argument must be type character if provided")
}
result <- .Call("XGBoosterDumpModel_R", model, fmap, as.integer(with.stats), PACKAGE = "xgboost")
if(is.null(fname)) return(str_split(result, "\n") %>% unlist %>% str_replace_all("\t"," ") %>% Filter(function(x) x != "", .))
writeLines(result, fname)
#unlist(str_split(a, "\n"))==""
return(TRUE)
TRUE
}