[R] simplified the code; parameter style consistency
This commit is contained in:
parent
8473b18c3d
commit
56bd442b31
@ -2,8 +2,8 @@
|
|||||||
#'
|
#'
|
||||||
#' Save xgb.DMatrix object to binary file
|
#' Save xgb.DMatrix object to binary file
|
||||||
#'
|
#'
|
||||||
#' @param DMatrix the DMatrix object
|
#' @param dmatrix the \code{xgb.DMatrix} object
|
||||||
#' @param fname the name of the binary file.
|
#' @param fname the name of the file to write.
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data(agaricus.train, package='xgboost')
|
#' data(agaricus.train, package='xgboost')
|
||||||
@ -12,15 +12,12 @@
|
|||||||
#' xgb.DMatrix.save(dtrain, 'xgb.DMatrix.data')
|
#' xgb.DMatrix.save(dtrain, 'xgb.DMatrix.data')
|
||||||
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
||||||
#' @export
|
#' @export
|
||||||
xgb.DMatrix.save <- function(DMatrix, fname) {
|
xgb.DMatrix.save <- function(dmatrix, fname) {
|
||||||
if (typeof(fname) != "character") {
|
if (typeof(fname) != "character")
|
||||||
stop("xgb.save: fname must be character")
|
stop("fname must be character")
|
||||||
}
|
if (class(dmatrix) != "xgb.DMatrix")
|
||||||
if (class(DMatrix) == "xgb.DMatrix") {
|
stop("the input data must be xgb.DMatrix")
|
||||||
.Call("XGDMatrixSaveBinary_R", DMatrix, fname, as.integer(FALSE),
|
|
||||||
PACKAGE = "xgboost")
|
.Call("XGDMatrixSaveBinary_R", dmatrix, fname, 0L, PACKAGE = "xgboost")
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
|
||||||
stop("xgb.DMatrix.save: the input must be xgb.DMatrix")
|
|
||||||
return(FALSE)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,11 +2,6 @@
|
|||||||
#'
|
#'
|
||||||
#' Save a xgboost model to text file. Could be parsed later.
|
#' Save a xgboost model to text file. Could be parsed later.
|
||||||
#'
|
#'
|
||||||
#' @importFrom magrittr %>%
|
|
||||||
#' @importFrom stringr str_replace
|
|
||||||
#' @importFrom data.table fread
|
|
||||||
#' @importFrom data.table :=
|
|
||||||
#' @importFrom data.table setnames
|
|
||||||
#' @param model the model object.
|
#' @param model the model object.
|
||||||
#' @param fname the name of the text file where to save the model text dump. If not provided or set to \code{NULL} the function will return the model as a \code{character} vector.
|
#' @param fname the name of the text file where to save the model text dump. If not provided or set to \code{NULL} the function will return the model as a \code{character} vector.
|
||||||
#' @param fmap feature map file representing the type of feature.
|
#' @param fmap feature map file representing the type of feature.
|
||||||
@ -15,10 +10,11 @@
|
|||||||
#' See demo/ for walkthrough example in R, and
|
#' See demo/ for walkthrough example in R, and
|
||||||
#' \url{https://github.com/dmlc/xgboost/blob/master/demo/data/featmap.txt}
|
#' \url{https://github.com/dmlc/xgboost/blob/master/demo/data/featmap.txt}
|
||||||
#' for example Format.
|
#' for example Format.
|
||||||
#' @param with.stats whether dump statistics of splits
|
#' @param with_stats whether dump statistics of splits
|
||||||
#' When this option is on, the model dump comes with two additional statistics:
|
#' When this option is on, the model dump comes with two additional statistics:
|
||||||
#' gain is the approximate loss function gain we get in each split;
|
#' gain is the approximate loss function gain we get in each split;
|
||||||
#' cover is the sum of second order gradient in each node.
|
#' cover is the sum of second order gradient in each node.
|
||||||
|
#' @param ... currently not used
|
||||||
#'
|
#'
|
||||||
#' @return
|
#' @return
|
||||||
#' if fname is not provided or set to \code{NULL} the function will return the model as a \code{character} vector. Otherwise it will return \code{TRUE}.
|
#' if fname is not provided or set to \code{NULL} the function will return the model as a \code{character} vector. Otherwise it will return \code{TRUE}.
|
||||||
@ -28,43 +24,36 @@
|
|||||||
#' data(agaricus.test, package='xgboost')
|
#' data(agaricus.test, package='xgboost')
|
||||||
#' train <- agaricus.train
|
#' train <- agaricus.train
|
||||||
#' test <- agaricus.test
|
#' test <- agaricus.test
|
||||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic")
|
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
||||||
#' # save the model in file 'xgb.model.dump'
|
#' # save the model in file 'xgb.model.dump'
|
||||||
#' xgb.dump(bst, 'xgb.model.dump', with.stats = TRUE)
|
#' xgb.dump(bst, 'xgb.model.dump', with_stats = TRUE)
|
||||||
#'
|
#'
|
||||||
#' # print the model without saving it to a file
|
#' # print the model without saving it to a file
|
||||||
#' print(xgb.dump(bst))
|
#' print(xgb.dump(bst))
|
||||||
#' @export
|
#' @export
|
||||||
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
|
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ...) {
|
||||||
if (class(model) != "xgb.Booster") {
|
check.deprecation(...)
|
||||||
stop("model: argument must be type xgb.Booster")
|
if (class(model) != "xgb.Booster")
|
||||||
|
stop("model: argument must be of type xgb.Booster")
|
||||||
|
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1))
|
||||||
|
stop("fname: argument must be of type character (when provided)")
|
||||||
|
if (!(class(fmap) %in% c("character", "NULL") && length(fmap) <= 1))
|
||||||
|
stop("fmap: argument must be of type character (when provided)")
|
||||||
|
|
||||||
|
model <- xgb.Booster.check(model)
|
||||||
|
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost")
|
||||||
|
|
||||||
|
if (is.null(fname))
|
||||||
|
model_dump <- str_replace_all(model_dump, '\t', '')
|
||||||
|
|
||||||
|
model_dump <- unlist(str_split(model_dump, '\n'))
|
||||||
|
model_dump <- grep('(^$|^0$)', model_dump, invert = TRUE, value = TRUE)
|
||||||
|
|
||||||
|
if (is.null(fname)) {
|
||||||
|
return(model_dump)
|
||||||
} else {
|
} else {
|
||||||
model <- xgb.Booster.check(model)
|
writeLines(model_dump, fname)
|
||||||
}
|
|
||||||
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) {
|
|
||||||
stop("fname: argument must be type character (when provided)")
|
|
||||||
}
|
|
||||||
if (!(class(fmap) %in% c("character", "NULL") && length(fname) <= 1)) {
|
|
||||||
stop("fmap: argument must be type character (when provided)")
|
|
||||||
}
|
|
||||||
|
|
||||||
longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost")
|
|
||||||
|
|
||||||
dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F)
|
|
||||||
|
|
||||||
setnames(dt, "Lines")
|
|
||||||
|
|
||||||
if(is.null(fname)) {
|
|
||||||
result <- dt[Lines != "0"][, Lines := str_replace(Lines, "^\t+", "")][Lines != ""][, paste(Lines)]
|
|
||||||
return(result)
|
|
||||||
} else {
|
|
||||||
result <- dt[Lines != "0"][Lines != ""][, paste(Lines)] %>% writeLines(fname)
|
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Avoid error messages during CRAN check.
|
|
||||||
# The reason is that these variables are never declared
|
|
||||||
# They are mainly column names inferred by Data.table...
|
|
||||||
globalVariables(c("Lines", "."))
|
|
||||||
|
|||||||
@ -3,29 +3,25 @@
|
|||||||
#' Save xgboost model from xgboost or xgb.train
|
#' Save xgboost model from xgboost or xgb.train
|
||||||
#'
|
#'
|
||||||
#' @param model the model object.
|
#' @param model the model object.
|
||||||
#' @param fname the name of the binary file.
|
#' @param fname the name of the file to write.
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data(agaricus.train, package='xgboost')
|
#' data(agaricus.train, package='xgboost')
|
||||||
#' data(agaricus.test, package='xgboost')
|
#' data(agaricus.test, package='xgboost')
|
||||||
#' train <- agaricus.train
|
#' train <- agaricus.train
|
||||||
#' test <- agaricus.test
|
#' test <- agaricus.test
|
||||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic")
|
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||||
#' xgb.save(bst, 'xgb.model')
|
#' xgb.save(bst, 'xgb.model')
|
||||||
#' bst <- xgb.load('xgb.model')
|
#' bst <- xgb.load('xgb.model')
|
||||||
#' pred <- predict(bst, test$data)
|
#' pred <- predict(bst, test$data)
|
||||||
#' @export
|
#' @export
|
||||||
xgb.save <- function(model, fname) {
|
xgb.save <- function(model, fname) {
|
||||||
if (typeof(fname) != "character") {
|
if (typeof(fname) != "character")
|
||||||
stop("xgb.save: fname must be character")
|
stop("fname must be character")
|
||||||
}
|
if (class(model) != "xgb.Booster")
|
||||||
if (class(model) == "xgb.Booster") {
|
stop("the input must be xgb.Booster. Use xgb.DMatrix.save to save xgb.DMatrix object.")
|
||||||
model <- xgb.Booster.check(model)
|
|
||||||
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
|
||||||
stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save
|
|
||||||
xgb.DMatrix object.")
|
|
||||||
return(FALSE)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,20 +10,14 @@
|
|||||||
#' data(agaricus.test, package='xgboost')
|
#' data(agaricus.test, package='xgboost')
|
||||||
#' train <- agaricus.train
|
#' train <- agaricus.train
|
||||||
#' test <- agaricus.test
|
#' test <- agaricus.test
|
||||||
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
#' eta = 1, nthread = 2, nround = 2,objective = "binary:logistic")
|
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||||
#' raw <- xgb.save.raw(bst)
|
#' raw <- xgb.save.raw(bst)
|
||||||
#' bst <- xgb.load(raw)
|
#' bst <- xgb.load(raw)
|
||||||
#' pred <- predict(bst, test$data)
|
#' pred <- predict(bst, test$data)
|
||||||
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
xgb.save.raw <- function(model) {
|
xgb.save.raw <- function(model) {
|
||||||
if (class(model) == "xgb.Booster"){
|
model <- xgb.get.handle(model)
|
||||||
model <- model$handle
|
.Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost")
|
||||||
}
|
|
||||||
if (class(model) == "xgb.Booster.handle") {
|
|
||||||
raw <- .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost")
|
|
||||||
return(raw)
|
|
||||||
}
|
|
||||||
stop("xgb.raw: the input must be xgb.Booster.handle. Use xgb.DMatrix.save to save
|
|
||||||
xgb.DMatrix object.")
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user