move stream to rabit part, support rabit on yarn
This commit is contained in:
parent
9f7c6fe271
commit
a8d5af39fd
7
Makefile
7
Makefile
@ -16,6 +16,13 @@ ifeq ($(cxx11),1)
|
|||||||
else
|
else
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
ifeq ($(hdfs),1)
|
||||||
|
CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
|
||||||
|
LDFLAGS+= -L$(HADOOP_HDFS_HOME)/lib/native -L$(JAVA_HOME)/jre/lib/amd64/server -lhdfs -ljvm
|
||||||
|
else
|
||||||
|
CFLAGS+= -DRABIT_USE_HDFS=0
|
||||||
|
endif
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = xgboost
|
BIN = xgboost
|
||||||
MOCKBIN = xgboost.mock
|
MOCKBIN = xgboost.mock
|
||||||
|
|||||||
@ -25,10 +25,10 @@ save_period = 0
|
|||||||
# eval[test] = "agaricus.txt.test"
|
# eval[test] = "agaricus.txt.test"
|
||||||
|
|
||||||
# Plz donot modify the following parameters
|
# Plz donot modify the following parameters
|
||||||
# The path of training data
|
# The path of training data, with prefix hdfs
|
||||||
data = stdin
|
#data = hdfs:/data/
|
||||||
# The path of model file
|
# The path of model file
|
||||||
model_out = stdout
|
#model_out =
|
||||||
# split pattern of xgboost
|
# split pattern of xgboost
|
||||||
dsplit = row
|
dsplit = row
|
||||||
# evaluate on training data as well each round
|
# evaluate on training data as well each round
|
||||||
|
|||||||
@ -8,11 +8,16 @@ fi
|
|||||||
# put the local training file to HDFS
|
# put the local training file to HDFS
|
||||||
hadoop fs -mkdir $3/data
|
hadoop fs -mkdir $3/data
|
||||||
hadoop fs -put ../../demo/data/agaricus.txt.train $3/data
|
hadoop fs -put ../../demo/data/agaricus.txt.train $3/data
|
||||||
|
hadoop fs -put ../../demo/data/agaricus.txt.test $3/data
|
||||||
|
|
||||||
../../subtree/rabit/tracker/rabit_hadoop.py -n $1 -nt $2 -i $3/data/agaricus.txt.train -o $3/mushroom.final.model ../../xgboost mushroom.hadoop.conf nthread=$2
|
# running rabit, pass address in hdfs
|
||||||
|
../../subtree/rabit/tracker/rabit_yarn.py -n $1 --vcores $2 ../../xgboost mushroom.hadoop.conf nthread=$2\
|
||||||
|
data=hdfs://$3/data/agaricus.txt.train\
|
||||||
|
eval[test]=hdfs://$3/data/agaricus.txt.test\
|
||||||
|
model_out=hdfs://$3/mushroom.final.model
|
||||||
|
|
||||||
# get the final model file
|
# get the final model file
|
||||||
hadoop fs -get $3/mushroom.final.model/part-00000 ./final.model
|
hadoop fs -get $3/mushroom.final.model final.model
|
||||||
|
|
||||||
# output prediction task=pred
|
# output prediction task=pred
|
||||||
../../xgboost mushroom.hadoop.conf task=pred model_in=final.model test:data=../../demo/data/agaricus.txt.test
|
../../xgboost mushroom.hadoop.conf task=pred model_in=final.model test:data=../../demo/data/agaricus.txt.test
|
||||||
|
|||||||
@ -14,10 +14,11 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace io {
|
namespace io {
|
||||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
|
DataMatrix* LoadDataMatrix(const char *fname, bool silent,
|
||||||
if (!std::strcmp(fname, "stdin")) {
|
bool savebuffer, bool loadsplit) {
|
||||||
|
if (!std::strcmp(fname, "stdin") || loadsplit) {
|
||||||
DMatrixSimple *dmat = new DMatrixSimple();
|
DMatrixSimple *dmat = new DMatrixSimple();
|
||||||
dmat->LoadText(fname, silent);
|
dmat->LoadText(fname, silent, loadsplit);
|
||||||
return dmat;
|
return dmat;
|
||||||
}
|
}
|
||||||
int magic;
|
int magic;
|
||||||
|
|||||||
@ -19,9 +19,14 @@ typedef learner::DMatrix DataMatrix;
|
|||||||
* \param fname file name to be loaded
|
* \param fname file name to be loaded
|
||||||
* \param silent whether print message during loading
|
* \param silent whether print message during loading
|
||||||
* \param savebuffer whether temporal buffer the file if the file is in text format
|
* \param savebuffer whether temporal buffer the file if the file is in text format
|
||||||
|
* \param loadsplit whether we only load a split of input files
|
||||||
|
* such that each worker node get a split of the data
|
||||||
* \return a loaded DMatrix
|
* \return a loaded DMatrix
|
||||||
*/
|
*/
|
||||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent = false, bool savebuffer = true);
|
DataMatrix* LoadDataMatrix(const char *fname,
|
||||||
|
bool silent,
|
||||||
|
bool savebuffer,
|
||||||
|
bool loadsplit);
|
||||||
/*!
|
/*!
|
||||||
* \brief save DataMatrix into stream,
|
* \brief save DataMatrix into stream,
|
||||||
* note: the saved dmatrix format may not be in exactly same as input
|
* note: the saved dmatrix format may not be in exactly same as input
|
||||||
|
|||||||
@ -11,12 +11,14 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "../data.h"
|
#include "../data.h"
|
||||||
#include "../utils/utils.h"
|
#include "../utils/utils.h"
|
||||||
#include "../learner/dmatrix.h"
|
#include "../learner/dmatrix.h"
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#include "./simple_fmatrix-inl.hpp"
|
#include "./simple_fmatrix-inl.hpp"
|
||||||
|
#include "../sync/sync.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace io {
|
namespace io {
|
||||||
@ -77,63 +79,59 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
return row_ptr_.size() - 2;
|
return row_ptr_.size() - 2;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load from text file
|
* \brief load split of input, used in distributed mode
|
||||||
* \param fname name of text data
|
* \param uri the uri of input
|
||||||
|
* \param loadsplit whether loadsplit of data or all the data
|
||||||
* \param silent whether print information or not
|
* \param silent whether print information or not
|
||||||
*/
|
*/
|
||||||
inline void LoadText(const char* fname, bool silent = false) {
|
inline void LoadText(const char *uri, bool silent = false, bool loadsplit = false) {
|
||||||
using namespace std;
|
int rank = 0, npart = 1;
|
||||||
|
if (loadsplit) {
|
||||||
|
rank = rabit::GetRank();
|
||||||
|
npart = rabit::GetWorldSize();
|
||||||
|
}
|
||||||
|
rabit::io::InputSplit *in =
|
||||||
|
rabit::io::CreateInputSplit(uri, rank, npart);
|
||||||
this->Clear();
|
this->Clear();
|
||||||
FILE* file;
|
std::string line;
|
||||||
if (!strcmp(fname, "stdin")) {
|
while (in->NextLine(&line)) {
|
||||||
file = stdin;
|
float label;
|
||||||
} else {
|
std::istringstream ss(line);
|
||||||
file = utils::FopenCheck(fname, "r");
|
std::vector<RowBatch::Entry> feats;
|
||||||
}
|
ss >> label;
|
||||||
float label; bool init = true;
|
while (!ss.eof()) {
|
||||||
char tmp[1024];
|
RowBatch::Entry e;
|
||||||
std::vector<RowBatch::Entry> feats;
|
if (!(ss >> e.index)) break;
|
||||||
while (fscanf(file, "%s", tmp) == 1) {
|
ss.ignore(32, ':');
|
||||||
RowBatch::Entry e;
|
if (!(ss >> e.fvalue)) break;
|
||||||
if (sscanf(tmp, "%u:%f", &e.index, &e.fvalue) == 2) {
|
|
||||||
feats.push_back(e);
|
feats.push_back(e);
|
||||||
} else {
|
|
||||||
if (!init) {
|
|
||||||
info.labels.push_back(label);
|
|
||||||
this->AddRow(feats);
|
|
||||||
}
|
|
||||||
feats.clear();
|
|
||||||
utils::Check(sscanf(tmp, "%f", &label) == 1, "invalid LibSVM format");
|
|
||||||
init = false;
|
|
||||||
}
|
}
|
||||||
|
info.labels.push_back(label);
|
||||||
|
this->AddRow(feats);
|
||||||
}
|
}
|
||||||
|
delete in;
|
||||||
info.labels.push_back(label);
|
|
||||||
this->AddRow(feats);
|
|
||||||
|
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
||||||
static_cast<unsigned long>(info.num_row()),
|
static_cast<unsigned long>(info.num_row()),
|
||||||
static_cast<unsigned long>(info.num_col()),
|
static_cast<unsigned long>(info.num_col()),
|
||||||
static_cast<unsigned long>(row_data_.size()), fname);
|
static_cast<unsigned long>(row_data_.size()), uri);
|
||||||
}
|
|
||||||
if (file != stdin) {
|
|
||||||
fclose(file);
|
|
||||||
}
|
}
|
||||||
// try to load in additional file
|
// try to load in additional file
|
||||||
std::string name = fname;
|
if (!loadsplit) {
|
||||||
std::string gname = name + ".group";
|
std::string name = uri;
|
||||||
if (info.TryLoadGroup(gname.c_str(), silent)) {
|
std::string gname = name + ".group";
|
||||||
utils::Check(info.group_ptr.back() == info.num_row(),
|
if (info.TryLoadGroup(gname.c_str(), silent)) {
|
||||||
"DMatrix: group data does not match the number of rows in features");
|
utils::Check(info.group_ptr.back() == info.num_row(),
|
||||||
}
|
"DMatrix: group data does not match the number of rows in features");
|
||||||
std::string wname = name + ".weight";
|
}
|
||||||
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
std::string wname = name + ".weight";
|
||||||
utils::Check(info.weights.size() == info.num_row(),
|
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
||||||
"DMatrix: weight data does not match the number of rows in features");
|
utils::Check(info.weights.size() == info.num_row(),
|
||||||
}
|
"DMatrix: weight data does not match the number of rows in features");
|
||||||
std::string mname = name + ".base_margin";
|
}
|
||||||
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
|
std::string mname = name + ".base_margin";
|
||||||
|
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -12,7 +12,6 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include "../sync/sync.h"
|
#include "../sync/sync.h"
|
||||||
#include "../utils/io.h"
|
#include "../utils/io.h"
|
||||||
#include "../utils/base64.h"
|
|
||||||
#include "./objective.h"
|
#include "./objective.h"
|
||||||
#include "./evaluation.h"
|
#include "./evaluation.h"
|
||||||
#include "../gbm/gbm.h"
|
#include "../gbm/gbm.h"
|
||||||
@ -178,44 +177,37 @@ class BoostLearner : public rabit::ISerializable {
|
|||||||
}
|
}
|
||||||
// rabit load model from rabit checkpoint
|
// rabit load model from rabit checkpoint
|
||||||
virtual void Load(rabit::IStream &fi) {
|
virtual void Load(rabit::IStream &fi) {
|
||||||
RabitStreamAdapter fs(fi);
|
|
||||||
// for row split, we should not keep pbuffer
|
// for row split, we should not keep pbuffer
|
||||||
this->LoadModel(fs, distributed_mode != 2, false);
|
this->LoadModel(fi, distributed_mode != 2, false);
|
||||||
}
|
}
|
||||||
// rabit save model to rabit checkpoint
|
// rabit save model to rabit checkpoint
|
||||||
virtual void Save(rabit::IStream &fo) const {
|
virtual void Save(rabit::IStream &fo) const {
|
||||||
RabitStreamAdapter fs(fo);
|
|
||||||
// for row split, we should not keep pbuffer
|
// for row split, we should not keep pbuffer
|
||||||
this->SaveModel(fs, distributed_mode != 2);
|
this->SaveModel(fo, distributed_mode != 2);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from file
|
* \brief load model from file
|
||||||
* \param fname file name
|
* \param fname file name
|
||||||
*/
|
*/
|
||||||
inline void LoadModel(const char *fname) {
|
inline void LoadModel(const char *fname) {
|
||||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
utils::IStream *fi = rabit::io::CreateStream(fname, "r");
|
||||||
utils::FileStream fi(fp);
|
|
||||||
std::string header; header.resize(4);
|
std::string header; header.resize(4);
|
||||||
// check header for different binary encode
|
// check header for different binary encode
|
||||||
// can be base64 or binary
|
// can be base64 or binary
|
||||||
if (fi.Read(&header[0], 4) != 0) {
|
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
|
||||||
// base64 format
|
// base64 format
|
||||||
if (header == "bs64") {
|
if (header == "bs64") {
|
||||||
utils::Base64InStream bsin(fp);
|
utils::Base64InStream bsin(fi);
|
||||||
bsin.InitPosition();
|
bsin.InitPosition();
|
||||||
this->LoadModel(bsin);
|
this->LoadModel(bsin);
|
||||||
fclose(fp);
|
} else if (header == "binf") {
|
||||||
return;
|
this->LoadModel(*fi);
|
||||||
}
|
} else {
|
||||||
if (header == "binf") {
|
delete fi;
|
||||||
this->LoadModel(fi);
|
fi = rabit::io::CreateStream(fname, "r");
|
||||||
fclose(fp);
|
this->LoadModel(*fi);
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fi.Seek(0);
|
delete fi;
|
||||||
this->LoadModel(fi);
|
|
||||||
fclose(fp);
|
|
||||||
}
|
}
|
||||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
||||||
fo.Write(&mparam, sizeof(ModelParam));
|
fo.Write(&mparam, sizeof(ModelParam));
|
||||||
@ -226,33 +218,20 @@ class BoostLearner : public rabit::ISerializable {
|
|||||||
/*!
|
/*!
|
||||||
* \brief save model into file
|
* \brief save model into file
|
||||||
* \param fname file name
|
* \param fname file name
|
||||||
|
* \param save_base64 whether save in base64 format
|
||||||
*/
|
*/
|
||||||
inline void SaveModel(const char *fname) const {
|
inline void SaveModel(const char *fname, bool save_base64 = false) const {
|
||||||
FILE *fp;
|
utils::IStream *fo = rabit::io::CreateStream(fname, "w");
|
||||||
bool use_stdout = false;;
|
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||||
#ifndef XGBOOST_STRICT_CXX98_
|
fo->Write("bs64\t", 5);
|
||||||
if (!strcmp(fname, "stdout")) {
|
utils::Base64OutStream bout(fo);
|
||||||
fp = stdout;
|
|
||||||
use_stdout = true;
|
|
||||||
} else
|
|
||||||
#endif
|
|
||||||
{
|
|
||||||
fp = utils::FopenCheck(fname, "wb");
|
|
||||||
}
|
|
||||||
utils::FileStream fo(fp);
|
|
||||||
std::string header;
|
|
||||||
if (save_base64 != 0|| use_stdout) {
|
|
||||||
fo.Write("bs64\t", 5);
|
|
||||||
utils::Base64OutStream bout(fp);
|
|
||||||
this->SaveModel(bout);
|
this->SaveModel(bout);
|
||||||
bout.Finish('\n');
|
bout.Finish('\n');
|
||||||
} else {
|
} else {
|
||||||
fo.Write("binf", 4);
|
fo->Write("binf", 4);
|
||||||
this->SaveModel(fo);
|
this->SaveModel(*fo);
|
||||||
}
|
|
||||||
if (!use_stdout) {
|
|
||||||
fclose(fp);
|
|
||||||
}
|
}
|
||||||
|
delete fo;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief check if data matrix is ready to be used by training,
|
* \brief check if data matrix is ready to be used by training,
|
||||||
@ -512,23 +491,6 @@ class BoostLearner : public rabit::ISerializable {
|
|||||||
// data structure field
|
// data structure field
|
||||||
/*! \brief the entries indicates that we have internal prediction cache */
|
/*! \brief the entries indicates that we have internal prediction cache */
|
||||||
std::vector<CacheEntry> cache_;
|
std::vector<CacheEntry> cache_;
|
||||||
|
|
||||||
private:
|
|
||||||
// adapt rabit stream to utils stream
|
|
||||||
struct RabitStreamAdapter : public utils::IStream {
|
|
||||||
// rabit stream
|
|
||||||
rabit::IStream &fs;
|
|
||||||
// constructr
|
|
||||||
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
|
|
||||||
// destructor
|
|
||||||
virtual ~RabitStreamAdapter(void){}
|
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
|
||||||
return fs.Read(ptr, size);
|
|
||||||
}
|
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
|
||||||
fs.Write(ptr, size);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
} // namespace learner
|
} // namespace learner
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#include "../../subtree/rabit/include/rabit.h"
|
#include "../../subtree/rabit/include/rabit.h"
|
||||||
|
#include "../../subtree/rabit/rabit-learn/io/io.h"
|
||||||
#endif // XGBOOST_SYNC_H_
|
#endif // XGBOOST_SYNC_H_
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,205 +0,0 @@
|
|||||||
#ifndef XGBOOST_UTILS_BASE64_H_
|
|
||||||
#define XGBOOST_UTILS_BASE64_H_
|
|
||||||
/*!
|
|
||||||
* \file base64.h
|
|
||||||
* \brief data stream support to input and output from/to base64 stream
|
|
||||||
* base64 is easier to store and pass as text format in mapreduce
|
|
||||||
* \author Tianqi Chen
|
|
||||||
*/
|
|
||||||
#include <cctype>
|
|
||||||
#include <cstdio>
|
|
||||||
#include "./utils.h"
|
|
||||||
#include "./io.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
|
||||||
namespace utils {
|
|
||||||
/*! \brief namespace of base64 decoding and encoding table */
|
|
||||||
namespace base64 {
|
|
||||||
const char DecodeTable[] = {
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
62, // '+'
|
|
||||||
0, 0, 0,
|
|
||||||
63, // '/'
|
|
||||||
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
|
|
||||||
0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
|
||||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
|
|
||||||
0, 0, 0, 0, 0, 0,
|
|
||||||
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
|
|
||||||
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
|
|
||||||
};
|
|
||||||
static const char EncodeTable[] =
|
|
||||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
||||||
} // namespace base64
|
|
||||||
/*! \brief the stream that reads from base64, note we take from file pointers */
|
|
||||||
class Base64InStream: public IStream {
|
|
||||||
public:
|
|
||||||
explicit Base64InStream(FILE *fp) : fp(fp) {
|
|
||||||
num_prev = 0; tmp_ch = 0;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief initialize the stream position to beginning of next base64 stream
|
|
||||||
* call this function before actually start read
|
|
||||||
*/
|
|
||||||
inline void InitPosition(void) {
|
|
||||||
// get a charater
|
|
||||||
do {
|
|
||||||
tmp_ch = fgetc(fp);
|
|
||||||
} while (isspace(tmp_ch));
|
|
||||||
}
|
|
||||||
/*! \brief whether current position is end of a base64 stream */
|
|
||||||
inline bool IsEOF(void) const {
|
|
||||||
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
|
|
||||||
}
|
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
|
||||||
using base64::DecodeTable;
|
|
||||||
if (size == 0) return 0;
|
|
||||||
// use tlen to record left size
|
|
||||||
size_t tlen = size;
|
|
||||||
unsigned char *cptr = static_cast<unsigned char*>(ptr);
|
|
||||||
// if anything left, load from previous buffered result
|
|
||||||
if (num_prev != 0) {
|
|
||||||
if (num_prev == 2) {
|
|
||||||
if (tlen >= 2) {
|
|
||||||
*cptr++ = buf_prev[0];
|
|
||||||
*cptr++ = buf_prev[1];
|
|
||||||
tlen -= 2;
|
|
||||||
num_prev = 0;
|
|
||||||
} else {
|
|
||||||
// assert tlen == 1
|
|
||||||
*cptr++ = buf_prev[0]; --tlen;
|
|
||||||
buf_prev[0] = buf_prev[1];
|
|
||||||
num_prev = 1;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// assert num_prev == 1
|
|
||||||
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (tlen == 0) return size;
|
|
||||||
int nvalue;
|
|
||||||
// note: everything goes with 4 bytes in Base64
|
|
||||||
// so we process 4 bytes a unit
|
|
||||||
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
|
|
||||||
// first byte
|
|
||||||
nvalue = DecodeTable[tmp_ch] << 18;
|
|
||||||
{
|
|
||||||
// second byte
|
|
||||||
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
|
||||||
"invalid base64 format");
|
|
||||||
nvalue |= DecodeTable[tmp_ch] << 12;
|
|
||||||
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// third byte
|
|
||||||
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
|
||||||
"invalid base64 format");
|
|
||||||
// handle termination
|
|
||||||
if (tmp_ch == '=') {
|
|
||||||
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
|
|
||||||
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
|
|
||||||
"invalid base64 format");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
nvalue |= DecodeTable[tmp_ch] << 6;
|
|
||||||
if (tlen) {
|
|
||||||
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
|
|
||||||
} else {
|
|
||||||
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// fourth byte
|
|
||||||
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
|
||||||
"invalid base64 format");
|
|
||||||
if (tmp_ch == '=') {
|
|
||||||
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
|
|
||||||
"invalid base64 format");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
nvalue |= DecodeTable[tmp_ch];
|
|
||||||
if (tlen) {
|
|
||||||
*cptr++ = nvalue & 0xFF; --tlen;
|
|
||||||
} else {
|
|
||||||
buf_prev[num_prev ++] = nvalue & 0xFF;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// get next char
|
|
||||||
tmp_ch = fgetc(fp);
|
|
||||||
}
|
|
||||||
if (kStrictCheck) {
|
|
||||||
Check(tlen == 0, "Base64InStream: read incomplete");
|
|
||||||
}
|
|
||||||
return size - tlen;
|
|
||||||
}
|
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
|
||||||
utils::Error("Base64InStream do not support write");
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
FILE *fp;
|
|
||||||
int tmp_ch;
|
|
||||||
int num_prev;
|
|
||||||
unsigned char buf_prev[2];
|
|
||||||
// whether we need to do strict check
|
|
||||||
static const bool kStrictCheck = false;
|
|
||||||
};
|
|
||||||
/*! \brief the stream that write to base64, note we take from file pointers */
|
|
||||||
class Base64OutStream: public IStream {
|
|
||||||
public:
|
|
||||||
explicit Base64OutStream(FILE *fp) : fp(fp) {
|
|
||||||
buf_top = 0;
|
|
||||||
}
|
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
|
||||||
using base64::EncodeTable;
|
|
||||||
size_t tlen = size;
|
|
||||||
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
|
|
||||||
while (tlen) {
|
|
||||||
while (buf_top < 3 && tlen != 0) {
|
|
||||||
buf[++buf_top] = *cptr++; --tlen;
|
|
||||||
}
|
|
||||||
if (buf_top == 3) {
|
|
||||||
// flush 4 bytes out
|
|
||||||
fputc(EncodeTable[buf[1] >> 2], fp);
|
|
||||||
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
|
|
||||||
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
|
|
||||||
fputc(EncodeTable[buf[3] & 0x3F], fp);
|
|
||||||
buf_top = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
|
||||||
Error("Base64OutStream do not support read");
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief finish writing of all current base64 stream, do some post processing
|
|
||||||
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
|
|
||||||
*/
|
|
||||||
inline void Finish(char endch = EOF) {
|
|
||||||
using base64::EncodeTable;
|
|
||||||
if (buf_top == 1) {
|
|
||||||
fputc(EncodeTable[buf[1] >> 2], fp);
|
|
||||||
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
|
|
||||||
fputc('=', fp);
|
|
||||||
fputc('=', fp);
|
|
||||||
}
|
|
||||||
if (buf_top == 2) {
|
|
||||||
fputc(EncodeTable[buf[1] >> 2], fp);
|
|
||||||
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
|
|
||||||
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
|
|
||||||
fputc('=', fp);
|
|
||||||
}
|
|
||||||
buf_top = 0;
|
|
||||||
if (endch != EOF) fputc(endch, fp);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
FILE *fp;
|
|
||||||
int buf_top;
|
|
||||||
unsigned char buf[4];
|
|
||||||
};
|
|
||||||
} // namespace utils
|
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_UTILS_BASE64_H_
|
|
||||||
173
src/utils/io.h
173
src/utils/io.h
@ -5,6 +5,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
|
#include "../sync/sync.h"
|
||||||
/*!
|
/*!
|
||||||
* \file io.h
|
* \file io.h
|
||||||
* \brief general stream interface for serialization, I/O
|
* \brief general stream interface for serialization, I/O
|
||||||
@ -12,168 +13,13 @@
|
|||||||
*/
|
*/
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
/*!
|
// reuse the definitions of streams
|
||||||
* \brief interface of stream I/O, used to serialize model
|
typedef rabit::IStream IStream;
|
||||||
*/
|
typedef rabit::utils::ISeekStream ISeekStream;
|
||||||
class IStream {
|
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
||||||
public:
|
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
||||||
/*!
|
typedef rabit::io::Base64InStream Base64InStream;
|
||||||
* \brief read data from stream
|
typedef rabit::io::Base64OutStream Base64OutStream;
|
||||||
* \param ptr pointer to memory buffer
|
|
||||||
* \param size size of block
|
|
||||||
* \return usually is the size of data readed
|
|
||||||
*/
|
|
||||||
virtual size_t Read(void *ptr, size_t size) = 0;
|
|
||||||
/*!
|
|
||||||
* \brief write data to stream
|
|
||||||
* \param ptr pointer to memory buffer
|
|
||||||
* \param size size of block
|
|
||||||
*/
|
|
||||||
virtual void Write(const void *ptr, size_t size) = 0;
|
|
||||||
/*! \brief virtual destructor */
|
|
||||||
virtual ~IStream(void) {}
|
|
||||||
|
|
||||||
public:
|
|
||||||
// helper functions to write various of data structures
|
|
||||||
/*!
|
|
||||||
* \brief binary serialize a vector
|
|
||||||
* \param vec vector to be serialized
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline void Write(const std::vector<T> &vec) {
|
|
||||||
uint64_t sz = static_cast<uint64_t>(vec.size());
|
|
||||||
this->Write(&sz, sizeof(sz));
|
|
||||||
if (sz != 0) {
|
|
||||||
this->Write(&vec[0], sizeof(T) * sz);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief binary load a vector
|
|
||||||
* \param out_vec vector to be loaded
|
|
||||||
* \return whether load is successfull
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline bool Read(std::vector<T> *out_vec) {
|
|
||||||
uint64_t sz;
|
|
||||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
|
||||||
out_vec->resize(sz);
|
|
||||||
if (sz != 0) {
|
|
||||||
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief binary serialize a string
|
|
||||||
* \param str the string to be serialized
|
|
||||||
*/
|
|
||||||
inline void Write(const std::string &str) {
|
|
||||||
uint64_t sz = static_cast<uint64_t>(str.length());
|
|
||||||
this->Write(&sz, sizeof(sz));
|
|
||||||
if (sz != 0) {
|
|
||||||
this->Write(&str[0], sizeof(char) * sz);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief binary load a string
|
|
||||||
* \param out_str string to be loaded
|
|
||||||
* \return whether load is successful
|
|
||||||
*/
|
|
||||||
inline bool Read(std::string *out_str) {
|
|
||||||
uint64_t sz;
|
|
||||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
|
||||||
out_str->resize(sz);
|
|
||||||
if (sz != 0) {
|
|
||||||
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief interface of i/o stream that support seek */
|
|
||||||
class ISeekStream: public IStream {
|
|
||||||
public:
|
|
||||||
/*! \brief seek to certain position of the file */
|
|
||||||
virtual void Seek(size_t pos) = 0;
|
|
||||||
/*! \brief tell the position of the stream */
|
|
||||||
virtual size_t Tell(void) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief fixed size memory buffer */
|
|
||||||
struct MemoryFixSizeBuffer : public ISeekStream {
|
|
||||||
public:
|
|
||||||
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
|
||||||
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
|
|
||||||
curr_ptr_ = 0;
|
|
||||||
}
|
|
||||||
virtual ~MemoryFixSizeBuffer(void) {}
|
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
|
||||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
|
||||||
"read can not have position excceed buffer length");
|
|
||||||
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
|
||||||
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
|
||||||
curr_ptr_ += nread;
|
|
||||||
return nread;
|
|
||||||
}
|
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
|
||||||
if (size == 0) return;
|
|
||||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
|
||||||
"write position exceed fixed buffer size");
|
|
||||||
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
|
||||||
curr_ptr_ += size;
|
|
||||||
}
|
|
||||||
virtual void Seek(size_t pos) {
|
|
||||||
curr_ptr_ = static_cast<size_t>(pos);
|
|
||||||
}
|
|
||||||
virtual size_t Tell(void) {
|
|
||||||
return curr_ptr_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/*! \brief in memory buffer */
|
|
||||||
char *p_buffer_;
|
|
||||||
/*! \brief current pointer */
|
|
||||||
size_t buffer_size_;
|
|
||||||
/*! \brief current pointer */
|
|
||||||
size_t curr_ptr_;
|
|
||||||
}; // class MemoryFixSizeBuffer
|
|
||||||
|
|
||||||
/*! \brief a in memory buffer that can be read and write as stream interface */
|
|
||||||
struct MemoryBufferStream : public ISeekStream {
|
|
||||||
public:
|
|
||||||
MemoryBufferStream(std::string *p_buffer)
|
|
||||||
: p_buffer_(p_buffer) {
|
|
||||||
curr_ptr_ = 0;
|
|
||||||
}
|
|
||||||
virtual ~MemoryBufferStream(void) {}
|
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
|
||||||
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
|
||||||
"read can not have position excceed buffer length");
|
|
||||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
|
||||||
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
|
||||||
curr_ptr_ += nread;
|
|
||||||
return nread;
|
|
||||||
}
|
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
|
||||||
if (size == 0) return;
|
|
||||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
|
||||||
p_buffer_->resize(curr_ptr_+size);
|
|
||||||
}
|
|
||||||
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
|
||||||
curr_ptr_ += size;
|
|
||||||
}
|
|
||||||
virtual void Seek(size_t pos) {
|
|
||||||
curr_ptr_ = static_cast<size_t>(pos);
|
|
||||||
}
|
|
||||||
virtual size_t Tell(void) {
|
|
||||||
return curr_ptr_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/*! \brief in memory buffer */
|
|
||||||
std::string *p_buffer_;
|
|
||||||
/*! \brief current pointer */
|
|
||||||
size_t curr_ptr_;
|
|
||||||
}; // class MemoryBufferStream
|
|
||||||
|
|
||||||
/*! \brief implementation of file i/o stream */
|
/*! \brief implementation of file i/o stream */
|
||||||
class FileStream : public ISeekStream {
|
class FileStream : public ISeekStream {
|
||||||
@ -194,6 +40,9 @@ class FileStream : public ISeekStream {
|
|||||||
virtual size_t Tell(void) {
|
virtual size_t Tell(void) {
|
||||||
return std::ftell(fp);
|
return std::ftell(fp);
|
||||||
}
|
}
|
||||||
|
virtual bool AtEnd(void) const {
|
||||||
|
return std::feof(fp);
|
||||||
|
}
|
||||||
inline void Close(void) {
|
inline void Close(void) {
|
||||||
if (fp != NULL){
|
if (fp != NULL){
|
||||||
std::fclose(fp); fp = NULL;
|
std::fclose(fp); fp = NULL;
|
||||||
|
|||||||
@ -36,14 +36,8 @@ class BoostLearnTask {
|
|||||||
this->SetParam("silent", "1");
|
this->SetParam("silent", "1");
|
||||||
save_period = 0;
|
save_period = 0;
|
||||||
}
|
}
|
||||||
// whether need data rank
|
// initialized the result
|
||||||
bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
|
rabit::Init(argc, argv);
|
||||||
// if need data rank in loading, initialize rabit engine before load data
|
|
||||||
// otherwise, initialize rabit engine after loading data
|
|
||||||
// lazy initialization of rabit engine can be helpful in speculative execution
|
|
||||||
if (need_data_rank) rabit::Init(argc, argv);
|
|
||||||
this->InitData();
|
|
||||||
if (!need_data_rank) rabit::Init(argc, argv);
|
|
||||||
if (rabit::IsDistributed()) {
|
if (rabit::IsDistributed()) {
|
||||||
std::string pname = rabit::GetProcessorName();
|
std::string pname = rabit::GetProcessorName();
|
||||||
fprintf(stderr, "start %s:%d\n", pname.c_str(), rabit::GetRank());
|
fprintf(stderr, "start %s:%d\n", pname.c_str(), rabit::GetRank());
|
||||||
@ -54,6 +48,8 @@ class BoostLearnTask {
|
|||||||
if (rabit::GetRank() != 0) {
|
if (rabit::GetRank() != 0) {
|
||||||
this->SetParam("silent", "2");
|
this->SetParam("silent", "2");
|
||||||
}
|
}
|
||||||
|
this->InitData();
|
||||||
|
|
||||||
if (task == "train") {
|
if (task == "train") {
|
||||||
// if task is training, will try recover from checkpoint
|
// if task is training, will try recover from checkpoint
|
||||||
this->TaskTrain();
|
this->TaskTrain();
|
||||||
@ -135,17 +131,22 @@ class BoostLearnTask {
|
|||||||
train_path = s_tmp;
|
train_path = s_tmp;
|
||||||
load_part = 1;
|
load_part = 1;
|
||||||
}
|
}
|
||||||
|
bool loadsplit = data_split == "row";
|
||||||
if (name_fmap != "NULL") fmap.LoadText(name_fmap.c_str());
|
if (name_fmap != "NULL") fmap.LoadText(name_fmap.c_str());
|
||||||
if (task == "dump") return;
|
if (task == "dump") return;
|
||||||
if (task == "pred") {
|
if (task == "pred") {
|
||||||
data = io::LoadDataMatrix(test_path.c_str(), silent != 0, use_buffer != 0);
|
data = io::LoadDataMatrix(test_path.c_str(), silent != 0, use_buffer != 0, loadsplit);
|
||||||
} else {
|
} else {
|
||||||
// training
|
// training
|
||||||
data = io::LoadDataMatrix(train_path.c_str(), silent != 0 && load_part == 0, use_buffer != 0);
|
data = io::LoadDataMatrix(train_path.c_str(),
|
||||||
|
silent != 0 && load_part == 0,
|
||||||
|
use_buffer != 0, loadsplit);
|
||||||
utils::Assert(eval_data_names.size() == eval_data_paths.size(), "BUG");
|
utils::Assert(eval_data_names.size() == eval_data_paths.size(), "BUG");
|
||||||
for (size_t i = 0; i < eval_data_names.size(); ++i) {
|
for (size_t i = 0; i < eval_data_names.size(); ++i) {
|
||||||
deval.push_back(io::LoadDataMatrix(eval_data_paths[i].c_str(), silent != 0, use_buffer != 0));
|
deval.push_back(io::LoadDataMatrix(eval_data_paths[i].c_str(),
|
||||||
|
silent != 0,
|
||||||
|
use_buffer != 0,
|
||||||
|
loadsplit));
|
||||||
devalall.push_back(deval.back());
|
devalall.push_back(deval.back());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,13 @@ import rabit_tracker as tracker
|
|||||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||||
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
|
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
|
||||||
|
|
||||||
assert os.path.exists(YARN_JAR_PATH), ("cannot find \"%s\", please run build.sh on the yarn folder" % YARN_JAR_PATH)
|
if not os.path.exists(YARN_JAR_PATH):
|
||||||
|
warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
|
||||||
|
cmd = 'cd %;./build.sh' % os.path.dirname(__file__) + '/../yarn/'
|
||||||
|
print cmd
|
||||||
|
subprocess.check_call(cmd, shell = True, env = os.environ)
|
||||||
|
assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
|
||||||
|
|
||||||
hadoop_binary = 'hadoop'
|
hadoop_binary = 'hadoop'
|
||||||
# code
|
# code
|
||||||
hadoop_home = os.getenv('HADOOP_HOME')
|
hadoop_home = os.getenv('HADOOP_HOME')
|
||||||
|
|||||||
@ -1,4 +1,8 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
if [ -z "$HADOOP_PREFIX" ]; then
|
||||||
|
echo "cannot found $HADOOP_PREFIX in the environment variable, please set it properly"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
|
CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
|
||||||
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
|
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
|
||||||
jar cf rabit-yarn.jar -C bin .
|
jar cf rabit-yarn.jar -C bin .
|
||||||
|
|||||||
@ -112,7 +112,7 @@ using namespace xgboost::wrapper;
|
|||||||
|
|
||||||
extern "C"{
|
extern "C"{
|
||||||
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
|
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
|
||||||
return LoadDataMatrix(fname, silent != 0, false);
|
return LoadDataMatrix(fname, silent != 0, false, false);
|
||||||
}
|
}
|
||||||
void* XGDMatrixCreateFromCSR(const bst_ulong *indptr,
|
void* XGDMatrixCreateFromCSR(const bst_ulong *indptr,
|
||||||
const unsigned *indices,
|
const unsigned *indices,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user