diff --git a/R-package/R/utils.R b/R-package/R/utils.R index fb3f59957..bcbde36d1 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -102,28 +102,42 @@ xgb.numrow <- function(dmat) { } # iteratively update booster with customized statistics xgb.iter.boost <- function(booster, dtrain, gpair) { - if (class(booster) != "xgb.Booster.handle") { + if (class(booster) != "xgb.Booster") { stop("xgb.iter.update: first argument must be type xgb.Booster") + } else { + if (is.null(booster$handle)) { + booster$handle <- xgb.load(booster$raw) + } else { + if (is.null(booster$raw)) + booster$raw <- xgb.save.raw(booster$handle) + } } if (class(dtrain) != "xgb.DMatrix") { stop("xgb.iter.update: second argument must be type xgb.DMatrix") } - .Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess, + .Call("XGBoosterBoostOneIter_R", booster$handle, dtrain, gpair$grad, gpair$hess, PACKAGE = "xgboost") return(TRUE) } # iteratively update booster with dtrain xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { - if (class(booster) != "xgb.Booster.handle") { + if (class(booster) != "xgb.Booster") { stop("xgb.iter.update: first argument must be type xgb.Booster") + } else { + if (is.null(booster$handle)) { + booster$handle <- xgb.load(booster$raw) + } else { + if (is.null(booster$raw)) + booster$raw <- xgb.save.raw(booster$handle) + } } if (class(dtrain) != "xgb.DMatrix") { stop("xgb.iter.update: second argument must be type xgb.DMatrix") } if (is.null(obj)) { - .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, + .Call("XGBoosterUpdateOneIter_R", booster$handle, as.integer(iter), dtrain, PACKAGE = "xgboost") } else { pred <- predict(booster, dtrain)