[R] simplified the code; parameter style consistency
This commit is contained in:
parent
8473b18c3d
commit
56bd442b31
@ -2,8 +2,8 @@
|
||||
#'
|
||||
#' Save xgb.DMatrix object to binary file
|
||||
#'
|
||||
#' @param DMatrix the DMatrix object
|
||||
#' @param fname the name of the binary file.
|
||||
#' @param dmatrix the \code{xgb.DMatrix} object
|
||||
#' @param fname the name of the file to write.
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
@ -12,15 +12,12 @@
|
||||
#' xgb.DMatrix.save(dtrain, 'xgb.DMatrix.data')
|
||||
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
||||
#' @export
|
||||
xgb.DMatrix.save <- function(DMatrix, fname) {
|
||||
if (typeof(fname) != "character") {
|
||||
stop("xgb.save: fname must be character")
|
||||
}
|
||||
if (class(DMatrix) == "xgb.DMatrix") {
|
||||
.Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE),
|
||||
PACKAGE = "xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
stop("xgb.DMatrix.save: the input must be xgb.DMatrix")
|
||||
return(FALSE)
|
||||
xgb.DMatrix.save <- function(dmatrix, fname) {
|
||||
if (typeof(fname) != "character")
|
||||
stop("fname must be character")
|
||||
if (class(dmatrix) != "xgb.DMatrix")
|
||||
stop("the input data must be xgb.DMatrix")
|
||||
|
||||
.Call("XGDMatrixSaveBinary_R", dmatrix, fname, 0L, PACKAGE = "xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
|
||||
@ -2,11 +2,6 @@
|
||||
#'
|
||||
#' Save a xgboost model to text file. Could be parsed later.
|
||||
#'
|
||||
#' @importFrom magrittr %>%
|
||||
#' @importFrom stringr str_replace
|
||||
#' @importFrom data.table fread
|
||||
#' @importFrom data.table :=
|
||||
#' @importFrom data.table setnames
|
||||
#' @param model the model object.
|
||||
#' @param fname the name of the text file where to save the model text dump. If not provided or set to \code{NULL} the function will return the model as a \code{character} vector.
|
||||
#' @param fmap feature map file representing the type of feature.
|
||||
@ -15,10 +10,11 @@
|
||||
#' See demo/ for walkthrough example in R, and
|
||||
#' \url{https://github.com/dmlc/xgboost/blob/master/demo/data/featmap.txt}
|
||||
#' for example Format.
|
||||
#' @param with.stats whether dump statistics of splits
|
||||
#' @param with_stats whether dump statistics of splits
|
||||
#' When this option is on, the model dump comes with two additional statistics:
|
||||
#' gain is the approximate loss function gain we get in each split;
|
||||
#' cover is the sum of second order gradient in each node.
|
||||
#' @param ... currently not used
|
||||
#'
|
||||
#' @return
|
||||
#' if fname is not provided or set to \code{NULL} the function will return the model as a \code{character} vector. Otherwise it will return \code{TRUE}.
|
||||
@ -28,43 +24,36 @@
|
||||
#' data(agaricus.test, package='xgboost')
|
||||
#' train <- agaricus.train
|
||||
#' test <- agaricus.test
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic")
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
||||
#' # save the model in file 'xgb.model.dump'
|
||||
#' xgb.dump(bst, 'xgb.model.dump', with.stats = TRUE)
|
||||
#' xgb.dump(bst, 'xgb.model.dump', with_stats = TRUE)
|
||||
#'
|
||||
#' # print the model without saving it to a file
|
||||
#' print(xgb.dump(bst))
|
||||
#' @export
|
||||
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
|
||||
if (class(model) != "xgb.Booster") {
|
||||
stop("model: argument must be type xgb.Booster")
|
||||
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ...) {
|
||||
check.deprecation(...)
|
||||
if (class(model) != "xgb.Booster")
|
||||
stop("model: argument must be of type xgb.Booster")
|
||||
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1))
|
||||
stop("fname: argument must be of type character (when provided)")
|
||||
if (!(class(fmap) %in% c("character", "NULL") && length(fmap) <= 1))
|
||||
stop("fmap: argument must be of type character (when provided)")
|
||||
|
||||
model <- xgb.Booster.check(model)
|
||||
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost")
|
||||
|
||||
if (is.null(fname))
|
||||
model_dump <- str_replace_all(model_dump, '\t', '')
|
||||
|
||||
model_dump <- unlist(str_split(model_dump, '\n'))
|
||||
model_dump <- grep('(^$|^0$)', model_dump, invert = TRUE, value = TRUE)
|
||||
|
||||
if (is.null(fname)) {
|
||||
return(model_dump)
|
||||
} else {
|
||||
model <- xgb.Booster.check(model)
|
||||
}
|
||||
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) {
|
||||
stop("fname: argument must be type character (when provided)")
|
||||
}
|
||||
if (!(class(fmap) %in% c("character", "NULL") && length(fname) <= 1)) {
|
||||
stop("fmap: argument must be type character (when provided)")
|
||||
}
|
||||
|
||||
longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost")
|
||||
|
||||
dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F)
|
||||
|
||||
setnames(dt, "Lines")
|
||||
|
||||
if(is.null(fname)) {
|
||||
result <- dt[Lines != "0"][, Lines := str_replace(Lines, "^\t+", "")][Lines != ""][, paste(Lines)]
|
||||
return(result)
|
||||
} else {
|
||||
result <- dt[Lines != "0"][Lines != ""][, paste(Lines)] %>% writeLines(fname)
|
||||
writeLines(model_dump, fname)
|
||||
return(TRUE)
|
||||
}
|
||||
}
|
||||
|
||||
# Avoid error messages during CRAN check.
|
||||
# The reason is that these variables are never declared
|
||||
# They are mainly column names inferred by Data.table...
|
||||
globalVariables(c("Lines", "."))
|
||||
|
||||
@ -3,29 +3,25 @@
|
||||
#' Save xgboost model from xgboost or xgb.train
|
||||
#'
|
||||
#' @param model the model object.
|
||||
#' @param fname the name of the binary file.
|
||||
#' @param fname the name of the file to write.
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' data(agaricus.test, package='xgboost')
|
||||
#' train <- agaricus.train
|
||||
#' test <- agaricus.test
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic")
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||
#' xgb.save(bst, 'xgb.model')
|
||||
#' bst <- xgb.load('xgb.model')
|
||||
#' pred <- predict(bst, test$data)
|
||||
#' @export
|
||||
xgb.save <- function(model, fname) {
|
||||
if (typeof(fname) != "character") {
|
||||
stop("xgb.save: fname must be character")
|
||||
}
|
||||
if (class(model) == "xgb.Booster") {
|
||||
model <- xgb.Booster.check(model)
|
||||
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save
|
||||
xgb.DMatrix object.")
|
||||
return(FALSE)
|
||||
if (typeof(fname) != "character")
|
||||
stop("fname must be character")
|
||||
if (class(model) != "xgb.Booster")
|
||||
stop("the input must be xgb.Booster. Use xgb.DMatrix.save to save xgb.DMatrix object.")
|
||||
|
||||
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
||||
return(TRUE)
|
||||
}
|
||||
|
||||
@ -10,20 +10,14 @@
|
||||
#' data(agaricus.test, package='xgboost')
|
||||
#' train <- agaricus.train
|
||||
#' test <- agaricus.test
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic")
|
||||
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||
#' raw <- xgb.save.raw(bst)
|
||||
#' bst <- xgb.load(raw)
|
||||
#' pred <- predict(bst, test$data)
|
||||
#'
|
||||
#' @export
|
||||
xgb.save.raw <- function(model) {
|
||||
if (class(model) == "xgb.Booster"){
|
||||
model <- model$handle
|
||||
}
|
||||
if (class(model) == "xgb.Booster.handle") {
|
||||
raw <- .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost")
|
||||
return(raw)
|
||||
}
|
||||
stop("xgb.raw: the input must be xgb.Booster.handle. Use xgb.DMatrix.save to save
|
||||
xgb.DMatrix object.")
|
||||
model <- xgb.get.handle(model)
|
||||
.Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost")
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user