add saveload to raw

This commit is contained in:
tqchen 2015-02-01 21:17:37 -08:00
parent 6e91846c55
commit dc3003cefd
9 changed files with 113 additions and 5 deletions

View File

@ -13,6 +13,7 @@ export(xgb.model.dt.tree)
export(xgb.plot.importance)
export(xgb.plot.tree)
export(xgb.save)
export(xgb.save.raw)
export(xgb.train)
export(xgboost)
exportMethods(predict)

View File

@ -57,10 +57,13 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
}
}
if (!is.null(modelfile)) {
if (typeof(modelfile) != "character") {
stop("xgb.Booster: modelfile must be character")
if (typeof(modelfile) == "character") {
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
} else if (typeof(modelfile) == "raw") {
.Call("XGBoosterLoadModelFromRaw_R", handle, modelfile, PACKAGE = "xgboost")
} else {
stop("xgb.Booster: modelfile must be character or raw vector")
}
.Call("XGBoosterLoadModel_R", handle, modelfile, PACKAGE = "xgboost")
}
return(structure(handle, class = "xgb.Booster"))
}

View File

@ -0,0 +1,27 @@
#' Save xgboost model to R's raw vector,
#' user can call xgb.load to load the model back from raw vector
#'
#' Serialize a model produced by xgboost or xgb.train into a raw vector
#'
#' @param model the model object.
#'
#' @return a raw vector holding the serialized model.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
#'                eta = 1, nround = 2,objective = "binary:logistic")
#' raw <- xgb.save.raw(bst)
#' bst <- xgb.load(raw)
#' pred <- predict(bst, test$data)
#' @export
#'
xgb.save.raw <- function(model) {
  # inherits() stays a single TRUE/FALSE even when the object carries
  # several classes, unlike comparing class(model) with ==
  if (!inherits(model, "xgb.Booster")) {
    stop("xgb.save.raw: the input must be xgb.Booster. ",
         "Use xgb.DMatrix.save to save xgb.DMatrix object.")
  }
  .Call("XGBoosterModelToRaw_R", model, PACKAGE = "xgboost")
}

View File

@ -58,6 +58,14 @@ pred2 <- predict(bst2, test$data)
# pred2 should be identical to pred
print(paste("sum(abs(pred2-pred))=", sum(abs(pred2-pred))))
# save model to R's raw vector
raw = xgb.save.raw(bst)
# load the model back from the raw vector
bst3 <- xgb.load(raw)
# predict with the reloaded model so the round trip is actually exercised
pred3 <- predict(bst3, test$data)
# pred3 should be identical to pred
print(paste("sum(abs(pred3-pred))=", sum(abs(pred3-pred))))
#----------------Advanced features --------------
# to use advanced features, we need to put data in xgb.DMatrix
dtrain <- xgb.DMatrix(data = train$data, label=train$label)

View File

@ -274,6 +274,23 @@ extern "C" {
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
_WrapperEnd();
}
// Load a model into the booster behind `handle` from the bytes of the R raw
// vector `raw`, by forwarding them to XGBoosterLoadModelFromBuffer.
void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
  _WrapperBegin();
  void *booster = R_ExternalPtrAddr(handle);
  const void *bytes = RAW(raw);
  const bst_ulong nbytes = length(raw);
  XGBoosterLoadModelFromBuffer(booster, bytes, nbytes);
  _WrapperEnd();
}
// Serialize the booster behind `handle` and return the bytes as a freshly
// allocated R raw vector (a copy of the buffer XGBoosterGetModelRaw exposes).
SEXP XGBoosterModelToRaw_R(SEXP handle) {
  bst_ulong olen;
  _WrapperBegin();
  const char *raw = XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen);
  _WrapperEnd();
  SEXP ret = PROTECT(allocVector(RAWSXP, olen));
  // XGBoosterGetModelRaw hands back NULL when the length is zero; calling
  // memcpy with a null source is undefined behaviour even for zero bytes.
  if (olen != 0) {
    memcpy(RAW(ret), raw, olen);
  }
  UNPROTECT(1);
  return ret;
}
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats) {
_WrapperBegin();
bst_ulong olen;

View File

@ -127,6 +127,17 @@ extern "C" {
* \param fname file name
*/
void XGBoosterSaveModel_R(SEXP handle, SEXP fname);
/*!
* \brief load model from raw array
* \param handle handle
* \param raw R raw vector whose bytes contain the model
*/
void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
/*!
* \brief save model into R's raw array
* \param handle handle
* \return newly allocated R raw vector holding a copy of the model bytes
*/
SEXP XGBoosterModelToRaw_R(SEXP handle);
/*!
* \brief dump model into a string
* \param handle handle

View File

@ -159,7 +159,9 @@ class BoostLearner : public rabit::ISerializable {
* \param with_pbuffer whether to load with predict buffer
* \param calc_num_feature whether call InitTrainer with calc_num_feature
*/
inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true, bool calc_num_feature = true) {
inline void LoadModel(utils::IStream &fi,
bool with_pbuffer = true,
bool calc_num_feature = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@ -192,8 +194,8 @@ class BoostLearner : public rabit::ISerializable {
*/
inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
std::string header; header.resize(4);
utils::FileStream fi(fp);
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
if (fi.Read(&header[0], 4) != 0) {

View File

@ -57,6 +57,22 @@ class Booster: public learner::BoostLearner {
learner::BoostLearner::LoadModel(fname);
this->init_model = true;
}
inline void LoadModelFromBuffer(const void *buf, size_t size) {
utils::MemoryFixSizeBuffer fs((void*)buf, size);
learner::BoostLearner::LoadModel(fs);
this->init_model = true;
}
/*!
 * \brief serialize the model into the internal string buffer model_str
 * \param out_len receives the number of bytes written
 * \return pointer to the first byte of model_str, or NULL when nothing was written
 */
inline const char *GetModelRaw(bst_ulong *out_len) {
  model_str.clear();
  utils::MemoryBufferStream sink(&model_str);
  learner::BoostLearner::SaveModel(sink);
  const bst_ulong nbytes = static_cast<bst_ulong>(model_str.length());
  *out_len = nbytes;
  if (nbytes != 0) {
    return &model_str[0];
  }
  return NULL;
}
inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, bst_ulong *len) {
model_dump = this->DumpModel(fmap, with_stats);
model_dump_cptr.resize(model_dump.size());
@ -69,6 +85,8 @@ class Booster: public learner::BoostLearner {
// temporal fields
// temporal data to save evaluation dump
std::string eval_str;
// temporary buffer holding the serialized raw model
std::string model_str;
// temporal space to save model dump
std::vector<std::string> model_dump;
std::vector<const char*> model_dump_cptr;
@ -295,6 +313,12 @@ extern "C"{
void XGBoosterSaveModel(const void *handle, const char *fname) {
static_cast<const Booster*>(handle)->SaveModel(fname);
}
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);
}
const char *XGBoosterGetModelRaw(void *handle, bst_ulong *out_len) {
return static_cast<Booster*>(handle)->GetModelRaw(out_len);
}
const char** XGBoosterDumpModel(void *handle, const char *fmap, int with_stats, bst_ulong *len){
utils::FeatMap featmap;
if (strlen(fmap) != 0) {

View File

@ -224,6 +224,21 @@ extern "C" {
* \param fname file name
*/
XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname);
/*!
* \brief load model from an in-memory buffer
* \param handle handle
* \param buf pointer to the beginning of the buffer
* \param len the length of the buffer
*/
XGB_DLL void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len);
/*!
* \brief save model into binary raw bytes, return the head of the array;
*  the user must copy the result out before the next xgboost call
* \param handle handle
* \param out_len the argument to hold the output length
* \return the pointer to the beginning of the binary buffer,
*  or NULL when the output length is zero
*/
XGB_DLL const char *XGBoosterGetModelRaw(void *handle, bst_ulong *out_len);
/*!
* \brief dump model, return array of strings representing model dump
* \param handle handle