From 5b611c355e61546b7c70d4f6a86b387e50a2b17e Mon Sep 17 00:00:00 2001 From: hetong007 Date: Mon, 9 Feb 2015 15:51:24 -0800 Subject: [PATCH] add handle and raw structure to xgb.Booster --- R-package/R/predict.xgb.Booster.R | 17 +++++++++++++++-- R-package/R/utils.R | 8 ++++---- R-package/R/xgb.dump.R | 9 ++++++++- R-package/R/xgb.load.R | 6 +++++- R-package/R/xgb.save.R | 5 ++++- R-package/R/xgb.save.raw.R | 8 ++++---- R-package/R/xgb.train.R | 9 ++++++--- 7 files changed, 46 insertions(+), 16 deletions(-) diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index 1e458e708..033bfab84 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -1,4 +1,7 @@ -setClass("xgb.Booster") +setClass("xgb.Booster.handle") +setClass("xgb.Booster", + slots = c(handle = "xgb.Booster.handle", + raw = "raw")) #' Predict method for eXtreme Gradient Boosting model #' @@ -30,6 +33,16 @@ setClass("xgb.Booster") setMethod("predict", signature = "xgb.Booster", definition = function(object, newdata, missing = NULL, outputmargin = FALSE, ntreelimit = NULL, predleaf = FALSE) { + if (class(object) != "xgb.Booster"){ + stop("predict: model in prediction must be of class xgb.Booster") + } else { + if (is.null(object$handle)) { + object$handle <- xgb.load(object$raw) + } else { + if (is.null(object$raw)) + object$raw <- xgb.save.raw(object$handle) + } + } if (class(newdata) != "xgb.DMatrix") { if (is.null(missing)) { newdata <- xgb.DMatrix(newdata) @@ -51,7 +64,7 @@ setMethod("predict", signature = "xgb.Booster", if (predleaf) { option <- option + 2 } - ret <- .Call("XGBoosterPredict_R", object, newdata, as.integer(option), + ret <- .Call("XGBoosterPredict_R", object$handle, newdata, as.integer(option), as.integer(ntreelimit), PACKAGE = "xgboost") if (predleaf){ len <- getinfo(newdata, "nrow") diff --git a/R-package/R/utils.R b/R-package/R/utils.R index b0c7f15ac..fb3f59957 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -65,7 +65,7 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) { stop("xgb.Booster: modelfile must be character or raw vector") } } - return(structure(handle, class = "xgb.Booster")) + return(structure(handle, class = "xgb.Booster.handle")) } ## ----the following are low level iteratively function, not needed if @@ -102,7 +102,7 @@ xgb.numrow <- function(dmat) { } # iteratively update booster with customized statistics xgb.iter.boost <- function(booster, dtrain, gpair) { - if (class(booster) != "xgb.Booster") { + if (class(booster) != "xgb.Booster.handle") { stop("xgb.iter.update: first argument must be type xgb.Booster") } if (class(dtrain) != "xgb.DMatrix") { @@ -115,7 +115,7 @@ xgb.iter.boost <- function(booster, dtrain, gpair) { # iteratively update booster with dtrain xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { - if (class(booster) != "xgb.Booster") { + if (class(booster) != "xgb.Booster.handle") { stop("xgb.iter.update: first argument must be type xgb.Booster") } if (class(dtrain) != "xgb.DMatrix") { @@ -135,7 +135,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { # iteratively evaluate one iteration xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) { - if (class(booster) != "xgb.Booster") { + if (class(booster) != "xgb.Booster.handle") { stop("xgb.eval: first argument must be type xgb.Booster") } if (typeof(watchlist) != "list") { diff --git a/R-package/R/xgb.dump.R b/R-package/R/xgb.dump.R index edeb03b5f..1f73eed2e 100644 --- a/R-package/R/xgb.dump.R +++ b/R-package/R/xgb.dump.R @@ -40,6 +40,13 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) { if (class(model) != "xgb.Booster") { stop("model: argument must be type xgb.Booster") + } else { + if (is.null(model$handle)) { + model$handle <- xgb.load(model$raw) + } else { + if (is.null(model$raw)) + model$raw <- xgb.save.raw(model$handle) + } } if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) { stop("fname: argument must be type character (when provided)") @@ -48,7 +55,7 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) { stop("fmap: argument must be type character (when provided)") } - longString <- .Call("XGBoosterDumpModel_R", model, fmap, as.integer(with.stats), PACKAGE = "xgboost") + longString <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with.stats), PACKAGE = "xgboost") dt <- fread(paste(longString, collapse = ""), sep = "\n", header = F) diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index af87e2b3c..87247b4a9 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -19,5 +19,9 @@ xgb.load <- function(modelfile) { if (is.null(modelfile)) stop("xgb.load: modelfile cannot be NULL") - xgb.Booster(modelfile = modelfile) + bst <- list(handle = NULL,raw = NULL) + class(bst) <- 'xgb.Booster' + bst$handle <- xgb.Booster(modelfile = modelfile) + bst$raw <- xgb.save.raw(bst$handle) + return(bst) } diff --git a/R-package/R/xgb.save.R b/R-package/R/xgb.save.R index 2a250a9af..0fecddfb5 100644 --- a/R-package/R/xgb.save.R +++ b/R-package/R/xgb.save.R @@ -22,7 +22,10 @@ xgb.save <- function(model, fname) { stop("xgb.save: fname must be character") } if (class(model) == "xgb.Booster") { - .Call("XGBoosterSaveModel_R", model, fname, PACKAGE = "xgboost") + if (is.null(model$handle)) { + model$handle <- xgb.load(model$raw) + } + .Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost") return(TRUE) } stop("xgb.save: the input must be xgb.Booster. Use xgb.DMatrix.save to save diff --git a/R-package/R/xgb.save.raw.R b/R-package/R/xgb.save.raw.R index d8ed6f526..91f7075bd 100644 --- a/R-package/R/xgb.save.raw.R +++ b/R-package/R/xgb.save.raw.R @@ -17,11 +17,11 @@ #' pred <- predict(bst, test$data) #' @export #' -xgb.save.raw <- function(model) { - if (class(model) == "xgb.Booster") { - raw <- .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost") +xgb.save.raw <- function(handle) { + if (class(handle) == "xgb.Booster.handle") { + raw <- .Call("XGBoosterModelToRaw_R", handle, PACKAGE = "xgboost") return(raw) } - stop("xgb.raw: the input must be xgb.Booster. Use xgb.DMatrix.save to save + stop("xgb.raw: the input must be xgb.Booster.handle. Use xgb.DMatrix.save to save xgb.DMatrix object.") } diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 06c39d76c..c6d29e6e3 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -86,13 +86,16 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), } params = append(params, list(...)) - bst <- xgb.Booster(params, append(watchlist, dtrain)) + bst <- list(handle = NULL,raw = NULL) + class(bst) <- 'xgb.Booster' + bst$handle <- xgb.Booster(params, append(watchlist, dtrain)) for (i in 1:nrounds) { - succ <- xgb.iter.update(bst, dtrain, i - 1, obj) + succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj) if (length(watchlist) != 0) { - msg <- xgb.iter.eval(bst, watchlist, i - 1, feval) + msg <- xgb.iter.eval(bst$handle, watchlist, i - 1, feval) cat(paste(msg, "\n", sep="")) } } + bst$raw <- xgb.save.raw(bst$handle) return(bst) }