add handle and raw structure to xgb.Booster
This commit is contained in:
parent
ea5860d574
commit
5b611c355e
@ -1,4 +1,7 @@
|
|||||||
setClass("xgb.Booster")
|
setClass("xgb.Booster.handle")
|
||||||
|
setClass("xgb.Booster",
|
||||||
|
slots = c(handle = "xgb.Booster.handle",
|
||||||
|
raw = "raw"))
|
||||||
|
|
||||||
#' Predict method for eXtreme Gradient Boosting model
|
#' Predict method for eXtreme Gradient Boosting model
|
||||||
#'
|
#'
|
||||||
@ -30,6 +33,16 @@ setClass("xgb.Booster")
|
|||||||
setMethod("predict", signature = "xgb.Booster",
|
setMethod("predict", signature = "xgb.Booster",
|
||||||
definition = function(object, newdata, missing = NULL,
|
definition = function(object, newdata, missing = NULL,
|
||||||
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
|
outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) {
|
||||||
|
if (class(object) != "xgb.Booster"){
|
||||||
|
stop("predict: model in prediction must be of class xgb.Booster")
|
||||||
|
} else {
|
||||||
|
if (is.null(object$handle)) {
|
||||||
|
object$handle <- xgb.load(object$raw)
|
||||||
|
} else {
|
||||||
|
if (is.null(object$raw))
|
||||||
|
object$raw <- xgb.save.raw(object$handle)
|
||||||
|
}
|
||||||
|
}
|
||||||
if (class(newdata) != "xgb.DMatrix") {
|
if (class(newdata) != "xgb.DMatrix") {
|
||||||
if (is.null(missing)) {
|
if (is.null(missing)) {
|
||||||
newdata <- xgb.DMatrix(newdata)
|
newdata <- xgb.DMatrix(newdata)
|
||||||
@ -51,7 +64,7 @@ setMethod("predict", signature = "xgb.Booster",
|
|||||||
if (predleaf) {
|
if (predleaf) {
|
||||||
option <- option + 2
|
option <- option + 2
|
||||||
}
|
}
|
||||||
ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option),
|
ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option),
|
||||||
as.integer(ntreelimit), PACKAGE = "xgboost")
|
as.integer(ntreelimit), PACKAGE = "xgboost")
|
||||||
if (predleaf){
|
if (predleaf){
|
||||||
len <- getinfo(newdata, "nrow")
|
len <- getinfo(newdata, "nrow")
|
||||||
|
|||||||
@ -65,7 +65,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
|
|||||||
stop("xgb.Booster: modelfile must be character or raw vector")
|
stop("xgb.Booster: modelfile must be character or raw vector")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return(structure(handle, class = "xgb.Booster"))
|
return(structure(handle, class = "xgb.Booster.handle"))
|
||||||
}
|
}
|
||||||
|
|
||||||
## ----the following are low level iteratively function, not needed if
|
## ----the following are low level iteratively function, not needed if
|
||||||
@ -102,7 +102,7 @@ xgb.numrow <- function(dmat) {
|
|||||||
}
|
}
|
||||||
# iteratively update booster with customized statistics
|
# iteratively update booster with customized statistics
|
||||||
xgb.iter.boost <- function(booster, dtrain, gpair) {
|
xgb.iter.boost <- function(booster, dtrain, gpair) {
|
||||||
if (class(booster) != "xgb.Booster") {
|
if (class(booster) != "xgb.Booster.handle") {
|
||||||
stop("xgb.iter.update: first argument must be type xgb.Booster")
|
stop("xgb.iter.update: first argument must be type xgb.Booster")
|
||||||
}
|
}
|
||||||
if (class(dtrain) != "xgb.DMatrix") {
|
if (class(dtrain) != "xgb.DMatrix") {
|
||||||
@ -115,7 +115,7 @@ xgb.iter.boost <- function(booster, dtrain, gpair) {
|
|||||||
|
|
||||||
# iteratively update booster with dtrain
|
# iteratively update booster with dtrain
|
||||||
xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
||||||
if (class(booster) != "xgb.Booster") {
|
if (class(booster) != "xgb.Booster.handle") {
|
||||||
stop("xgb.iter.update: first argument must be type xgb.Booster")
|
stop("xgb.iter.update: first argument must be type xgb.Booster")
|
||||||
}
|
}
|
||||||
if (class(dtrain) != "xgb.DMatrix") {
|
if (class(dtrain) != "xgb.DMatrix") {
|
||||||
@ -135,7 +135,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
|||||||
|
|
||||||
# iteratively evaluate one iteration
|
# iteratively evaluate one iteration
|
||||||
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) {
|
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) {
|
||||||
if (class(booster) != "xgb.Booster") {
|
if (class(booster) != "xgb.Booster.handle") {
|
||||||
stop("xgb.eval: first argument must be type xgb.Booster")
|
stop("xgb.eval: first argument must be type xgb.Booster")
|
||||||
}
|
}
|
||||||
if (typeof(watchlist) != "list") {
|
if (typeof(watchlist) != "list") {
|
||||||
|
|||||||
@ -40,6 +40,13 @@
|
|||||||
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
|
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
|
||||||
if (class(model) != "xgb.Booster") {
|
if (class(model) != "xgb.Booster") {
|
||||||
stop("model: argument must be type xgb.Booster")
|
stop("model: argument must be type xgb.Booster")
|
||||||
|
} else {
|
||||||
|
if (is.null(model$handle)) {
|
||||||
|
model$handle <- xgb.load(model$raw)
|
||||||
|
} else {
|
||||||
|
if (is.null(model$raw))
|
||||||
|
model$raw <- xgb.save.raw(model$handle)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) {
|
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) {
|
||||||
stop("fname: argument must be type character (when provided)")
|
stop("fname: argument must be type character (when provided)")
|
||||||
@ -48,7 +55,7 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) {
|
|||||||
stop("fmap: argument must be type character (when provided)")
|
stop("fmap: argument must be type character (when provided)")
|
||||||
}
|
}
|
||||||
|
|
||||||
longString <- .Call("XGBoosterDumpModel_R", model, fmap, as.integer(with.stats), PACKAGE = "xgboost")
|
longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost")
|
||||||
|
|
||||||
dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F)
|
dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F)
|
||||||
|
|
||||||
|
|||||||
@ -19,5 +19,9 @@
|
|||||||
xgb.load <- function(modelfile) {
|
xgb.load <- function(modelfile) {
|
||||||
if (is.null(modelfile))
|
if (is.null(modelfile))
|
||||||
stop("xgb.load: modelfile cannot be NULL")
|
stop("xgb.load: modelfile cannot be NULL")
|
||||||
xgb.Booster(modelfile = modelfile)
|
bst <- list(handle = NULL,raw = NULL)
|
||||||
|
class(bst) <- 'xgb.Booster'
|
||||||
|
bst$handle <- xgb.Booster(modelfile = modelfile)
|
||||||
|
bst$raw <- xgb.save.raw(bst$handle)
|
||||||
|
return(bst)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,7 +22,10 @@ xgb.save <- function(model, fname) {
|
|||||||
stop("xgb.save: fname must be character")
|
stop("xgb.save: fname must be character")
|
||||||
}
|
}
|
||||||
if (class(model) == "xgb.Booster") {
|
if (class(model) == "xgb.Booster") {
|
||||||
.Call("XGBoosterSaveModel_R", model, fname, PACKAGE = "xgboost")
|
if (is.null(model$handle)) {
|
||||||
|
model$handle <- xgb.load(model$raw)
|
||||||
|
}
|
||||||
|
.Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost")
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save
|
stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save
|
||||||
|
|||||||
@ -17,11 +17,11 @@
|
|||||||
#' pred <- predict(bst, test$data)
|
#' pred <- predict(bst, test$data)
|
||||||
#' @export
|
#' @export
|
||||||
#'
|
#'
|
||||||
xgb.save.raw <- function(model) {
|
xgb.save.raw <- function(handle) {
|
||||||
if (class(model) == "xgb.Booster") {
|
if (class(handle) == "xgb.Booster.handle") {
|
||||||
raw <- .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost")
|
raw <- .Call("XGBoosterModelToRaw_R", handle, PACKAGE = "xgboost")
|
||||||
return(raw)
|
return(raw)
|
||||||
}
|
}
|
||||||
stop("xgb.raw: the input must be xgb.Booster. Use xgb.DMatrix.save to save
|
stop("xgb.raw: the input must be xgb.Booster.handle. Use xgb.DMatrix.save to save
|
||||||
xgb.DMatrix object.")
|
xgb.DMatrix object.")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -86,13 +86,16 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
|||||||
}
|
}
|
||||||
params = append(params, list(...))
|
params = append(params, list(...))
|
||||||
|
|
||||||
bst <- xgb.Booster(params, append(watchlist, dtrain))
|
bst <- list(handle = NULL,raw = NULL)
|
||||||
|
class(bst) <- 'xgb.Booster'
|
||||||
|
bst$handle <- xgb.Booster(params, append(watchlist, dtrain))
|
||||||
for (i in 1:nrounds) {
|
for (i in 1:nrounds) {
|
||||||
succ <- xgb.iter.update(bst, dtrain, i - 1, obj)
|
succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj)
|
||||||
if (length(watchlist) != 0) {
|
if (length(watchlist) != 0) {
|
||||||
msg <- xgb.iter.eval(bst, watchlist, i - 1, feval)
|
msg <- xgb.iter.eval(bst$handle, watchlist, i - 1, feval)
|
||||||
cat(paste(msg, "\n", sep=""))
|
cat(paste(msg, "\n", sep=""))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
bst$raw <- xgb.save.raw(bst$handle)
|
||||||
return(bst)
|
return(bst)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user