diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index b1c3c10ca..52c40df9b 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -36,12 +36,7 @@ setMethod("predict", signature = "xgb.Booster", if (class(object) != "xgb.Booster"){ stop("predict: model in prediction must be of class xgb.Booster") } else { - if (is.null(object$handle)) { - object$handle <- xgb.load(object$raw) - } else { - if (is.null(object$raw)) - object$raw <- xgb.save.raw(object$handle) - } + object <- xgb.Booster.check(object, saveraw = FALSE) } if (class(newdata) != "xgb.DMatrix") { if (is.null(missing)) { diff --git a/R-package/R/predict.xgb.Booster.handle.R b/R-package/R/predict.xgb.Booster.handle.R index 05cbf891e..a38aeb64e 100644 --- a/R-package/R/predict.xgb.Booster.handle.R +++ b/R-package/R/predict.xgb.Booster.handle.R @@ -6,9 +6,9 @@ setMethod("predict", signature = "xgb.Booster.handle", stop("predict: model in prediction must be of class xgb.Booster.handle") } - bst <- list(handle = object,raw = NULL) - class(bst) <- 'xgb.Booster' - bst$raw <- xgb.save.raw(bst$handle) + bst <- xgb.handleToBooster(object) + # Avoid save a handle without update + # bst$raw <- xgb.save.raw(object) ret = predict(bst, ...) return(ret) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 5093382d4..bff6dd0e8 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -68,6 +68,26 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) { return(structure(handle, class = "xgb.Booster.handle")) } +# convert xgb.Booster.handle to xgb.Booster +xgb.handleToBooster <- function(handle) +{ + bst <- list(handle = handle, raw = NULL) + class(bst) <- "xgb.Booster" + return(bst) +} + +# Check whether an xgb.Booster object is complete +xgb.Booster.check <- function(bst, saveraw = TRUE) +{ + if (is.null(bst$handle)) { + bst$handle <- xgb.load(bst$raw) + } else { + if (is.null(bst$raw) && saveraw) + bst$raw <- xgb.save.raw(bst$handle) + } + return(bst) +} + ## ----the following are low level iteratively function, not needed if ## you do not want to use them --------------------------------------- # get dmatrix from data, label diff --git a/R-package/R/xgb.dump.R b/R-package/R/xgb.dump.R index 1f73eed2e..fa5fe4149 100644 --- a/R-package/R/xgb.dump.R +++ b/R-package/R/xgb.dump.R @@ -41,12 +41,7 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with.stats=FALSE) { if (class(model) != "xgb.Booster") { stop("model: argument must be type xgb.Booster") } else { - if (is.null(model$handle)) { - model$handle <- xgb.load(model$raw) - } else { - if (is.null(model$raw)) - model$raw <- xgb.save.raw(model$handle) - } + model <- xgb.Booster.check(model) } if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1)) { stop("fname: argument must be type character (when provided)") diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index 264176952..33d440530 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -20,9 +20,8 @@ xgb.load <- function(modelfile) { if (is.null(modelfile)) stop("xgb.load: modelfile cannot be NULL") - bst <- list(handle = NULL,raw = NULL) - class(bst) <- 'xgb.Booster' - bst$handle <- xgb.Booster(modelfile = modelfile) - bst$raw <- xgb.save.raw(bst$handle) + handle <- xgb.Booster(modelfile = modelfile) + bst <- xgb.handleToBooster(handle) + bst <- xgb.Booster.check(bst) return(bst) } diff --git a/R-package/R/xgb.save.R b/R-package/R/xgb.save.R index 0fecddfb5..59c5d2ecd 100644 --- a/R-package/R/xgb.save.R +++ b/R-package/R/xgb.save.R @@ -22,9 +22,7 @@ xgb.save <- function(model, fname) { stop("xgb.save: fname must be character") } if (class(model) == "xgb.Booster") { - if (is.null(model$handle)) { - model$handle <- xgb.load(model$raw) - } + model <- xgb.Booster.check(model) .Call("XGBoosterSaveModel_R", model$handle, fname, PACKAGE = "xgboost") return(TRUE) } diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index c6d29e6e3..250ba2fbf 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -86,9 +86,8 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), } params = append(params, list(...)) - bst <- list(handle = NULL,raw = NULL) - class(bst) <- 'xgb.Booster' - bst$handle <- xgb.Booster(params, append(watchlist, dtrain)) + handle <- xgb.Booster(params, append(watchlist, dtrain)) + bst <- xgb.handleToBooster(handle) for (i in 1:nrounds) { succ <- xgb.iter.update(bst$handle, dtrain, i - 1, obj) if (length(watchlist) != 0) { @@ -96,6 +95,6 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), cat(paste(msg, "\n", sep="")) } } - bst$raw <- xgb.save.raw(bst$handle) + bst <- xgb.Booster.check(bst) return(bst) }