move stream to rabit part, support rabit on yarn
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
#include <limits>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/base64.h"
|
||||
#include "./objective.h"
|
||||
#include "./evaluation.h"
|
||||
#include "../gbm/gbm.h"
|
||||
@@ -178,44 +177,37 @@ class BoostLearner : public rabit::ISerializable {
|
||||
}
|
||||
// rabit load model from rabit checkpoint
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
RabitStreamAdapter fs(fi);
|
||||
// for row split, we should not keep pbuffer
|
||||
this->LoadModel(fs, distributed_mode != 2, false);
|
||||
this->LoadModel(fi, distributed_mode != 2, false);
|
||||
}
|
||||
// rabit save model to rabit checkpoint
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
RabitStreamAdapter fs(fo);
|
||||
// for row split, we should not keep pbuffer
|
||||
this->SaveModel(fs, distributed_mode != 2);
|
||||
this->SaveModel(fo, distributed_mode != 2);
|
||||
}
|
||||
/*!
|
||||
* \brief load model from file
|
||||
* \param fname file name
|
||||
*/
|
||||
inline void LoadModel(const char *fname) {
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
utils::FileStream fi(fp);
|
||||
utils::IStream *fi = rabit::io::CreateStream(fname, "r");
|
||||
std::string header; header.resize(4);
|
||||
// check header for different binary encode
|
||||
// can be base64 or binary
|
||||
if (fi.Read(&header[0], 4) != 0) {
|
||||
// base64 format
|
||||
if (header == "bs64") {
|
||||
utils::Base64InStream bsin(fp);
|
||||
bsin.InitPosition();
|
||||
this->LoadModel(bsin);
|
||||
fclose(fp);
|
||||
return;
|
||||
}
|
||||
if (header == "binf") {
|
||||
this->LoadModel(fi);
|
||||
fclose(fp);
|
||||
return;
|
||||
}
|
||||
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
|
||||
// base64 format
|
||||
if (header == "bs64") {
|
||||
utils::Base64InStream bsin(fi);
|
||||
bsin.InitPosition();
|
||||
this->LoadModel(bsin);
|
||||
} else if (header == "binf") {
|
||||
this->LoadModel(*fi);
|
||||
} else {
|
||||
delete fi;
|
||||
fi = rabit::io::CreateStream(fname, "r");
|
||||
this->LoadModel(*fi);
|
||||
}
|
||||
fi.Seek(0);
|
||||
this->LoadModel(fi);
|
||||
fclose(fp);
|
||||
delete fi;
|
||||
}
|
||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
@@ -226,33 +218,20 @@ class BoostLearner : public rabit::ISerializable {
|
||||
/*!
|
||||
* \brief save model into file
|
||||
* \param fname file name
|
||||
* \param save_base64 whether save in base64 format
|
||||
*/
|
||||
inline void SaveModel(const char *fname) const {
|
||||
FILE *fp;
|
||||
bool use_stdout = false;;
|
||||
#ifndef XGBOOST_STRICT_CXX98_
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
fp = stdout;
|
||||
use_stdout = true;
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
fp = utils::FopenCheck(fname, "wb");
|
||||
}
|
||||
utils::FileStream fo(fp);
|
||||
std::string header;
|
||||
if (save_base64 != 0|| use_stdout) {
|
||||
fo.Write("bs64\t", 5);
|
||||
utils::Base64OutStream bout(fp);
|
||||
inline void SaveModel(const char *fname, bool save_base64 = false) const {
|
||||
utils::IStream *fo = rabit::io::CreateStream(fname, "w");
|
||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||
fo->Write("bs64\t", 5);
|
||||
utils::Base64OutStream bout(fo);
|
||||
this->SaveModel(bout);
|
||||
bout.Finish('\n');
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo.Write("binf", 4);
|
||||
this->SaveModel(fo);
|
||||
}
|
||||
if (!use_stdout) {
|
||||
fclose(fp);
|
||||
fo->Write("binf", 4);
|
||||
this->SaveModel(*fo);
|
||||
}
|
||||
delete fo;
|
||||
}
|
||||
/*!
|
||||
* \brief check if data matrix is ready to be used by training,
|
||||
@@ -512,23 +491,6 @@ class BoostLearner : public rabit::ISerializable {
|
||||
// data structure field
|
||||
/*! \brief the entries indicates that we have internal prediction cache */
|
||||
std::vector<CacheEntry> cache_;
|
||||
|
||||
private:
|
||||
// adapt rabit stream to utils stream
|
||||
struct RabitStreamAdapter : public utils::IStream {
|
||||
// rabit stream
|
||||
rabit::IStream &fs;
|
||||
// constructr
|
||||
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
|
||||
// destructor
|
||||
virtual ~RabitStreamAdapter(void){}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
return fs.Read(ptr, size);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
fs.Write(ptr, size);
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace learner
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user