diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index fab1546a2..ee8496d83 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -13,6 +13,7 @@ export(xgb.model.dt.tree) export(xgb.plot.importance) export(xgb.plot.tree) export(xgb.save) +export(xgb.save.raw) export(xgb.train) export(xgboost) exportMethods(predict) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 412132891..b0c7f15ac 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -57,10 +57,13 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) { } } if (!is.null(modelfile)) { - if (typeof(modelfile) != "character") { - stop("xgb.Booster: modelfile must be character") + if (typeof(modelfile) == "character") { + .Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost") + } else if (typeof(modelfile) == "raw") { + .Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost") + } else { + stop("xgb.Booster: modelfile must be character or raw vector") } - .Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost") } return(structure(handle, class = "xgb.Booster")) } diff --git a/R-package/R/xgb.save.raw.R b/R-package/R/xgb.save.raw.R new file mode 100644 index 000000000..11fc44470 --- /dev/null +++ b/R-package/R/xgb.save.raw.R @@ -0,0 +1,27 @@ +#' Save xgboost model to R's raw vector, +#' user can call xgb.load to load the model back from raw vector +#' +#' Save xgboost model from xgboost or xgb.train +#' +#' @param model the model object. +#' +#' @examples +#' data(agaricus.train, package='xgboost') +#' data(agaricus.test, package='xgboost') +#' train <- agaricus.train +#' test <- agaricus.test +#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2, +#' eta = 1, nround = 2,objective = "binary:logistic") +#' raw <- xgb.save(bst) +#' bst <- xgb.load(raw) +#' pred <- predict(bst, test$data) +#' @export +#' +xgb.save.raw <- function(model) { + if (class(model) == "xgb.Booster") { + raw <- .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost") + return(raw) + } + stop("xgb.raw: the input must be xgb.Booster. Use xgb.DMatrix.save to save + xgb.DMatrix object.") +} diff --git a/R-package/demo/basic_walkthrough.R b/R-package/demo/basic_walkthrough.R index 7e6914b31..25dd56612 100644 --- a/R-package/demo/basic_walkthrough.R +++ b/R-package/demo/basic_walkthrough.R @@ -58,6 +58,14 @@ pred2 <- predict(bst2, test$data) # pred2 should be identical to pred print(paste("sum(abs(pred2-pred))=", sum(abs(pred2-pred)))) +# save model to R's raw vector +raw = xgb.save.raw(bst) +# load binary model to R +bst3 <- xgb.load(raw) +pred3 <- predict(bst2, test$data) +# pred2 should be identical to pred +print(paste("sum(abs(pred3-pred))=", sum(abs(pred2-pred)))) + #----------------Advanced features -------------- # to use advanced features, we need to put data in xgb.DMatrix dtrain <- xgb.DMatrix(data = train$data, label=train$label) diff --git a/R-package/src/xgboost_R.cpp b/R-package/src/xgboost_R.cpp index aa17b30cc..9ebcec167 100644 --- a/R-package/src/xgboost_R.cpp +++ b/R-package/src/xgboost_R.cpp @@ -274,6 +274,23 @@ extern "C" { XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))); _WrapperEnd(); } + void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { + _WrapperBegin(); + XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle), + RAW(raw), + length(raw)); + _WrapperEnd(); + } + SEXP XGBoosterModelToRaw_R(SEXP handle) { + bst_ulong olen; + _WrapperBegin(); + const char *raw = XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen); + _WrapperEnd(); + SEXP ret = PROTECT(allocVector(RAWSXP, olen)); + memcpy(RAW(ret), raw, olen); + UNPROTECT(1); + return ret; + } SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats) { _WrapperBegin(); bst_ulong olen; diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 1e7606dd7..a86e85ffa 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -127,6 +127,17 @@ extern "C" { * \param fname file name */ void XGBoosterSaveModel_R(SEXP handle, SEXP fname); + /*! + * \brief load model from raw array + * \param handle handle + */ + void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw); + /*! + * \brief save model into R's raw array + * \param handle handle + * \return raw array + */ + SEXP XGBoosterModelToRaw_R(SEXP handle); /*! * \brief dump model into a string * \param handle handle diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index f7af5bfff..ee9153896 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -159,7 +159,9 @@ class BoostLearner : public rabit::ISerializable { * \param with_pbuffer whether to load with predict buffer * \param calc_num_feature whether call InitTrainer with calc_num_feature */ - inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true, bool calc_num_feature = true) { + inline void LoadModel(utils::IStream &fi, + bool with_pbuffer = true, + bool calc_num_feature = true) { utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0, "BoostLearner: wrong model format"); utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format"); @@ -192,8 +194,8 @@ class BoostLearner : public rabit::ISerializable { */ inline void LoadModel(const char *fname) { FILE *fp = utils::FopenCheck(fname, "rb"); - std::string header; header.resize(4); utils::FileStream fi(fp); + std::string header; header.resize(4); // check header for different binary encode // can be base64 or binary if (fi.Read(&header[0], 4) != 0) { diff --git a/wrapper/xgboost_wrapper.cpp b/wrapper/xgboost_wrapper.cpp index d744c3e22..2aa523494 100644 --- a/wrapper/xgboost_wrapper.cpp +++ b/wrapper/xgboost_wrapper.cpp @@ -57,6 +57,22 @@ class Booster: public learner::BoostLearner { learner::BoostLearner::LoadModel(fname); this->init_model = true; } + inline void LoadModelFromBuffer(const void *buf, size_t size) { + utils::MemoryFixSizeBuffer fs((void*)buf, size); + learner::BoostLearner::LoadModel(fs); + this->init_model = true; + } + inline const char *GetModelRaw(bst_ulong *out_len) { + model_str.resize(0); + utils::MemoryBufferStream fs(&model_str); + learner::BoostLearner::SaveModel(fs); + *out_len = static_cast(model_str.length()); + if (*out_len == 0) { + return NULL; + } else { + return &model_str[0]; + } + } inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, bst_ulong *len) { model_dump = this->DumpModel(fmap, with_stats); model_dump_cptr.resize(model_dump.size()); @@ -69,6 +85,8 @@ class Booster: public learner::BoostLearner { // temporal fields // temporal data to save evaluation dump std::string eval_str; + // temporal data to save model dump + std::string model_str; // temporal space to save model dump std::vector model_dump; std::vector model_dump_cptr; @@ -295,6 +313,12 @@ extern "C"{ void XGBoosterSaveModel(const void *handle, const char *fname) { static_cast(handle)->SaveModel(fname); } + void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) { + static_cast(handle)->LoadModelFromBuffer(buf, len); + } + const char *XGBoosterGetModelRaw(void *handle, bst_ulong *out_len) { + return static_cast(handle)->GetModelRaw(out_len); + } const char** XGBoosterDumpModel(void *handle, const char *fmap, int with_stats, bst_ulong *len){ utils::FeatMap featmap; if (strlen(fmap) != 0) { diff --git a/wrapper/xgboost_wrapper.h b/wrapper/xgboost_wrapper.h index 82fedb9d6..f236ee5da 100644 --- a/wrapper/xgboost_wrapper.h +++ b/wrapper/xgboost_wrapper.h @@ -224,6 +224,21 @@ extern "C" { * \param fname file name */ XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname); + /*! + * \brief load model from in memory buffer + * \param handle handle + * \param buf pointer to the buffer + * \param len the length of the buffer + */ + XGB_DLL void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len); + /*! + * \brief save model into binary raw bytes, return header of the array + * user must copy the result out, before next xgboost call + * \param handle handle + * \param out_len the argument to hold the output length + * \return the pointer to the beginning of binary buffer + */ + XGB_DLL const char *XGBoosterGetModelRaw(void *handle, bst_ulong *out_len); /*! * \brief dump model, return array of strings representing model dump * \param handle handle