diff --git a/R-package/R/predict.xgb.Booster.R b/R-package/R/predict.xgb.Booster.R index 033bfab84..b1c3c10ca 100644 --- a/R-package/R/predict.xgb.Booster.R +++ b/R-package/R/predict.xgb.Booster.R @@ -77,4 +77,4 @@ setMethod("predict", signature = "xgb.Booster", } return(ret) }) - + diff --git a/R-package/R/predict.xgb.Booster.handle.R b/R-package/R/predict.xgb.Booster.handle.R new file mode 100644 index 000000000..05cbf891e --- /dev/null +++ b/R-package/R/predict.xgb.Booster.handle.R @@ -0,0 +1,16 @@ +setClass("xgb.Booster.handle") + +setMethod("predict", signature = "xgb.Booster.handle", + definition = function(object, ...) { + if (class(object) != "xgb.Booster.handle"){ + stop("predict: model in prediction must be of class xgb.Booster.handle") + } + + bst <- list(handle = object,raw = NULL) + class(bst) <- 'xgb.Booster' + bst$raw <- xgb.save.raw(bst$handle) + + ret = predict(bst, ...) + return(ret) +}) + diff --git a/R-package/R/utils.R b/R-package/R/utils.R index bcbde36d1..5093382d4 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -102,42 +102,28 @@ xgb.numrow <- function(dmat) { } # iteratively update booster with customized statistics xgb.iter.boost <- function(booster, dtrain, gpair) { - if (class(booster) != "xgb.Booster") { - stop("xgb.iter.update: first argument must be type xgb.Booster") - } else { - if (is.null(booster$handle)) { - booster$handle <- xgb.load(booster$raw) - } else { - if (is.null(booster$raw)) - booster$raw <- xgb.save.raw(booster$handle) - } + if (class(booster) != "xgb.Booster.handle") { + stop("xgb.iter.update: first argument must be type xgb.Booster.handle") } if (class(dtrain) != "xgb.DMatrix") { stop("xgb.iter.update: second argument must be type xgb.DMatrix") } - .Call("XGBoosterBoostOneIter_R", booster$handle, dtrain, gpair$grad, gpair$hess, + .Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess, PACKAGE = "xgboost") return(TRUE) } # iteratively update booster with dtrain xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { - if (class(booster) != "xgb.Booster") { - stop("xgb.iter.update: first argument must be type xgb.Booster") - } else { - if (is.null(booster$handle)) { - booster$handle <- xgb.load(booster$raw) - } else { - if (is.null(booster$raw)) - booster$raw <- xgb.save.raw(booster$handle) - } + if (class(booster) != "xgb.Booster.handle") { + stop("xgb.iter.update: first argument must be type xgb.Booster.handle") } if (class(dtrain) != "xgb.DMatrix") { stop("xgb.iter.update: second argument must be type xgb.DMatrix") } if (is.null(obj)) { - .Call("XGBoosterUpdateOneIter_R", booster$handle, as.integer(iter), dtrain, + .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, PACKAGE = "xgboost") } else { pred <- predict(booster, dtrain) diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index 87247b4a9..264176952 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -19,6 +19,7 @@ xgb.load <- function(modelfile) { if (is.null(modelfile)) stop("xgb.load: modelfile cannot be NULL") + bst <- list(handle = NULL,raw = NULL) class(bst) <- 'xgb.Booster' bst$handle <- xgb.Booster(modelfile = modelfile)