move stream to rabit part, support rabit on yarn

This commit is contained in:
tqchen
2015-03-09 14:43:46 -07:00
parent 9f7c6fe271
commit a8d5af39fd
14 changed files with 134 additions and 500 deletions

View File

@@ -12,7 +12,6 @@
#include <limits>
#include "../sync/sync.h"
#include "../utils/io.h"
#include "../utils/base64.h"
#include "./objective.h"
#include "./evaluation.h"
#include "../gbm/gbm.h"
@@ -178,44 +177,37 @@ class BoostLearner : public rabit::ISerializable {
}
// rabit load model from rabit checkpoint
virtual void Load(rabit::IStream &fi) {
RabitStreamAdapter fs(fi);
// for row split, we should not keep pbuffer
this->LoadModel(fs, distributed_mode != 2, false);
this->LoadModel(fi, distributed_mode != 2, false);
}
// rabit save model to rabit checkpoint
virtual void Save(rabit::IStream &fo) const {
RabitStreamAdapter fs(fo);
// for row split, we should not keep pbuffer
this->SaveModel(fs, distributed_mode != 2);
this->SaveModel(fo, distributed_mode != 2);
}
/*!
* \brief load model from file
* \param fname file name
*/
inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
utils::FileStream fi(fp);
utils::IStream *fi = rabit::io::CreateStream(fname, "r");
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
if (fi.Read(&header[0], 4) != 0) {
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fp);
bsin.InitPosition();
this->LoadModel(bsin);
fclose(fp);
return;
}
if (header == "binf") {
this->LoadModel(fi);
fclose(fp);
return;
}
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fi);
bsin.InitPosition();
this->LoadModel(bsin);
} else if (header == "binf") {
this->LoadModel(*fi);
} else {
delete fi;
fi = rabit::io::CreateStream(fname, "r");
this->LoadModel(*fi);
}
fi.Seek(0);
this->LoadModel(fi);
fclose(fp);
delete fi;
}
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
@@ -226,33 +218,20 @@ class BoostLearner : public rabit::ISerializable {
/*!
* \brief save model into file
* \param fname file name
* \param save_base64 whether save in base64 format
*/
inline void SaveModel(const char *fname) const {
FILE *fp;
bool use_stdout = false;;
#ifndef XGBOOST_STRICT_CXX98_
if (!strcmp(fname, "stdout")) {
fp = stdout;
use_stdout = true;
} else
#endif
{
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
std::string header;
if (save_base64 != 0|| use_stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
inline void SaveModel(const char *fname, bool save_base64 = false) const {
utils::IStream *fo = rabit::io::CreateStream(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
fo->Write("bs64\t", 5);
utils::Base64OutStream bout(fo);
this->SaveModel(bout);
bout.Finish('\n');
bout.Finish('\n');
} else {
fo.Write("binf", 4);
this->SaveModel(fo);
}
if (!use_stdout) {
fclose(fp);
fo->Write("binf", 4);
this->SaveModel(*fo);
}
delete fo;
}
/*!
* \brief check if data matrix is ready to be used by training,
@@ -512,23 +491,6 @@ class BoostLearner : public rabit::ISerializable {
// data structure field
/*! \brief the entries indicates that we have internal prediction cache */
std::vector<CacheEntry> cache_;
private:
// adapt rabit stream to utils stream
struct RabitStreamAdapter : public utils::IStream {
// rabit stream
rabit::IStream &fs;
// constructr
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
// destructor
virtual ~RabitStreamAdapter(void){}
virtual size_t Read(void *ptr, size_t size) {
return fs.Read(ptr, size);
}
virtual void Write(const void *ptr, size_t size) {
fs.Write(ptr, size);
}
};
};
} // namespace learner
} // namespace xgboost