Merge commit '75bf97b57539e5572e7ae8eba72bac6562c63c07'

Conflicts:
	subtree/rabit/rabit-learn/io/line_split-inl.h
	subtree/rabit/yarn/build.sh
tqchen 2015-03-21 00:48:34 -07:00
commit 9ccbeaa8f0
34 changed files with 856 additions and 201 deletions


@ -19,6 +19,8 @@ namespace utils {
 /*! \brief interface of i/o stream that support seek */
 class ISeekStream: public IStream {
  public:
+  // virtual destructor
+  virtual ~ISeekStream(void) {}
   /*! \brief seek to certain position of the file */
   virtual void Seek(size_t pos) = 0;
   /*! \brief tell the position of the stream */

subtree/rabit/rabit-learn/.gitignore (vendored, new file)

@ -0,0 +1,2 @@
config.mk
*.log


@ -38,6 +38,7 @@ class StreamBufferReader {
       }
     }
   }
+  /*! \brief whether we are reaching the end of file */
   inline bool AtEnd(void) const {
     return read_len_ == 0;
   }


@ -66,27 +66,36 @@ class FileStream : public utils::ISeekStream {
 };
 /*! \brief line split from normal file system */
-class FileSplit : public LineSplitBase {
+class FileProvider : public LineSplitter::IFileProvider {
  public:
-  explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
-    LineSplitBase::SplitNames(&fnames_, uri, "#");
+  explicit FileProvider(const char *uri) {
+    LineSplitter::SplitNames(&fnames_, uri, "#");
     std::vector<size_t> fsize;
     for (size_t i = 0; i < fnames_.size(); ++i) {
       if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
         std::string tmp = fnames_[i].c_str() + 7;
         fnames_[i] = tmp;
       }
-      fsize.push_back(GetFileSize(fnames_[i].c_str()));
+      size_t fz = GetFileSize(fnames_[i].c_str());
+      if (fz != 0) {
+        fsize_.push_back(fz);
+      }
     }
-    LineSplitBase::Init(fsize, rank, nsplit);
   }
-  virtual ~FileSplit(void) {}
- protected:
-  virtual utils::ISeekStream *GetFile(size_t file_index) {
+  // destrucor
+  virtual ~FileProvider(void) {}
+  virtual utils::ISeekStream *Open(size_t file_index) {
     utils::Assert(file_index < fnames_.size(), "file index exceed bound");
     return new FileStream(fnames_[file_index].c_str(), "rb");
   }
+  virtual const std::vector<size_t> &FileSize(void) const {
+    return fsize_;
+  }
+ private:
+  // file sizes
+  std::vector<size_t> fsize_;
+  // file names
+  std::vector<std::string> fnames_;
   // get file size
   inline static size_t GetFileSize(const char *fname) {
     std::FILE *fp = utils::FopenCheck(fname, "rb");
@ -96,10 +105,6 @@ class FileSplit : public LineSplitBase {
     std::fclose(fp);
     return fsize;
   }
-
- private:
-  // file names
-  std::vector<std::string> fnames_;
 };
 } // namespace io
 } // namespace rabit

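The change above splits the old FileSplit in two: LineSplitter (in line_split-inl.h, further down) owns the byte-range and line logic, while file access moves behind the new IFileProvider interface. A minimal usage sketch of the combination; the include path and the sample URI are illustrative, not part of the patch:

#include <cstdio>
#include <string>
#include "rabit-learn/io/io.h"  // assumed include path for the rabit-learn io helpers

int main(void) {
  using namespace rabit::io;
  // worker 0 of 4: reads the first quarter of the bytes, rounded to whole lines
  InputSplit *split = new LineSplitter(new FileProvider("file://demo.txt"), 0, 4);
  std::string line;
  while (split->NextLine(&line)) {
    std::printf("%s\n", line.c_str());  // handle one line of this worker's slice
  }
  delete split;  // LineSplitter also deletes the provider it was handed
  return 0;
}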

@ -6,6 +6,7 @@
  * \author Tianqi Chen
  */
 #include <string>
+#include <cstdlib>
 #include <vector>
 #include <hdfs.h>
 #include <errno.h>
@ -15,11 +16,15 @@
 /*! \brief io interface */
 namespace rabit {
 namespace io {
-class HDFSStream : public utils::ISeekStream {
+class HDFSStream : public ISeekStream {
  public:
-  HDFSStream(hdfsFS fs, const char *fname, const char *mode)
-      : fs_(fs), at_end_(false) {
-    int flag;
+  HDFSStream(hdfsFS fs,
+             const char *fname,
+             const char *mode,
+             bool disconnect_when_done)
+      : fs_(fs), at_end_(false),
+        disconnect_when_done_(disconnect_when_done) {
+    int flag = 0;
     if (!strcmp(mode, "r")) {
       flag = O_RDONLY;
     } else if (!strcmp(mode, "w")) {
@ -35,6 +40,9 @@ class HDFSStream : public utils::ISeekStream {
   }
   virtual ~HDFSStream(void) {
     this->Close();
+    if (disconnect_when_done_) {
+      utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
+    }
   }
   virtual size_t Read(void *ptr, size_t size) {
     tSize nread = hdfsRead(fs_, fp_, ptr, size);
@ -86,52 +94,69 @@ class HDFSStream : public utils::ISeekStream {
     }
   }
+  inline static std::string GetNameNode(void) {
+    const char *nn = getenv("rabit_hdfs_namenode");
+    if (nn == NULL) {
+      return std::string("default");
+    } else {
+      return std::string(nn);
+    }
+  }
  private:
   hdfsFS fs_;
   hdfsFile fp_;
   bool at_end_;
+  bool disconnect_when_done_;
 };
 /*! \brief line split from normal file system */
-class HDFSSplit : public LineSplitBase {
+class HDFSProvider : public LineSplitter::IFileProvider {
  public:
-  explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) {
-    fs_ = hdfsConnect("default", 0);
+  explicit HDFSProvider(const char *uri) {
+    fs_ = hdfsConnect(HDFSStream::GetNameNode().c_str(), 0);
+    utils::Check(fs_ != NULL, "error when connecting to default HDFS");
     std::vector<std::string> paths;
-    LineSplitBase::SplitNames(&paths, uri, "#");
+    LineSplitter::SplitNames(&paths, uri, "#");
     // get the files
-    std::vector<size_t> fsize;
     for (size_t i = 0; i < paths.size(); ++i) {
       hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
+      utils::Check(info != NULL, "path %s do not exist", paths[i].c_str());
       if (info->mKind == 'D') {
         int nentry;
         hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
+        utils::Check(files != NULL, "error when ListDirectory %s", info->mName);
         for (int i = 0; i < nentry; ++i) {
-          if (files[i].mKind == 'F') {
-            fsize.push_back(files[i].mSize);
+          if (files[i].mKind == 'F' && files[i].mSize != 0) {
+            fsize_.push_back(files[i].mSize);
             fnames_.push_back(std::string(files[i].mName));
           }
         }
         hdfsFreeFileInfo(files, nentry);
       } else {
-        fsize.push_back(info->mSize);
-        fnames_.push_back(std::string(info->mName));
+        if (info->mSize != 0) {
+          fsize_.push_back(info->mSize);
+          fnames_.push_back(std::string(info->mName));
+        }
       }
       hdfsFreeFileInfo(info, 1);
     }
-    LineSplitBase::Init(fsize, rank, nsplit);
   }
-  virtual ~HDFSSplit(void) {}
- protected:
-  virtual utils::ISeekStream *GetFile(size_t file_index) {
+  virtual ~HDFSProvider(void) {
+    utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
+  }
+  virtual const std::vector<size_t> &FileSize(void) const {
+    return fsize_;
+  }
+  virtual ISeekStream *Open(size_t file_index) {
     utils::Assert(file_index < fnames_.size(), "file index exceed bound");
-    return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
+    return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
   }
  private:
   // hdfs handle
   hdfsFS fs_;
+  // file sizes
+  std::vector<size_t> fsize_;
   // file names
   std::vector<std::string> fnames_;
 };

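With GetNameNode() above, the namenode is no longer hard-coded to "default": both HDFSStream and HDFSProvider pick it up from the rabit_hdfs_namenode environment variable. A hedged sketch of selecting a cluster at launch time; the host name and the surrounding main() are illustrative:

#include <cstdlib>

int main(void) {
  // Point libhdfs at a specific namenode before any rabit io call runs;
  // if the variable is unset, GetNameNode() falls back to "default".
  setenv("rabit_hdfs_namenode", "namenode.example.com", 1);
  // ... CreateInputSplit("hdfs://...") / CreateStream("hdfs://...", "w")
  // would now connect through that namenode.
  return 0;
}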

@ -30,16 +30,16 @@ inline InputSplit *CreateInputSplit(const char *uri,
     return new SingleFileSplit(uri);
   }
   if (!strncmp(uri, "file://", 7)) {
-    return new FileSplit(uri, part, nsplit);
+    return new LineSplitter(new FileProvider(uri), part, nsplit);
   }
   if (!strncmp(uri, "hdfs://", 7)) {
 #if RABIT_USE_HDFS
-    return new HDFSSplit(uri, part, nsplit);
+    return new LineSplitter(new HDFSProvider(uri), part, nsplit);
 #else
     utils::Error("Please compile with RABIT_USE_HDFS=1");
 #endif
   }
-  return new FileSplit(uri, part, nsplit);
+  return new LineSplitter(new FileProvider(uri), part, nsplit);
 }
 /*!
  * \brief create an stream, the stream must be able to close
@ -55,7 +55,8 @@ inline IStream *CreateStream(const char *uri, const char *mode) {
   }
   if (!strncmp(uri, "hdfs://", 7)) {
 #if RABIT_USE_HDFS
-    return new HDFSStream(hdfsConnect("default", 0), uri, mode);
+    return new HDFSStream(hdfsConnect(HDFSStream::GetNameNode().c_str(), 0),
+                          uri, mode, true);
 #else
     utils::Error("Please compile with RABIT_USE_HDFS=1");
 #endif

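CreateStream gains the same namenode handling and passes disconnect_when_done=true, so deleting the returned stream also releases the HDFS connection it opened. A small usage sketch; the helper name and output URI are illustrative:

#include <string>
#include "rabit-learn/io/io.h"  // assumed include path

// Persist a buffer to local disk or HDFS, chosen by the URI scheme.
void SaveBlob(const char *uri, const std::string &blob) {
  rabit::IStream *fo = rabit::io::CreateStream(uri, "w");
  fo->Write(blob.c_str(), blob.length());
  delete fo;  // closes the file and, for hdfs://, disconnects the filesystem
}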

@ -19,6 +19,7 @@ namespace rabit {
  * \brief namespace to handle input split and filesystem interfacing
  */
 namespace io {
+/*! \brief reused ISeekStream's definition */
 typedef utils::ISeekStream ISeekStream;
 /*!
  * \brief user facing input split helper,


@ -15,11 +15,42 @@
 namespace rabit {
 namespace io {
-class LineSplitBase : public InputSplit {
+/*! \brief class that split the files by line */
+class LineSplitter : public InputSplit {
  public:
-  virtual ~LineSplitBase() {
-    if (fs_ != NULL) delete fs_;
+  class IFileProvider {
+   public:
+    /*!
+     * \brief get the seek stream of given file_index
+     * \return the corresponding seek stream at head of the stream
+     *  the seek stream's resource can be freed by calling delete
+     */
+    virtual ISeekStream *Open(size_t file_index) = 0;
+    /*!
+     * \return const reference to size of each files
+     */
+    virtual const std::vector<size_t> &FileSize(void) const = 0;
+    // virtual destructor
+    virtual ~IFileProvider() {}
+  };
+  // constructor
+  explicit LineSplitter(IFileProvider *provider,
+                        unsigned rank,
+                        unsigned nsplit)
+      : provider_(provider), fs_(NULL),
+        reader_(kBufferSize) {
+    this->Init(provider_->FileSize(), rank, nsplit);
   }
+  // destructor
+  virtual ~LineSplitter() {
+    if (fs_ != NULL) {
+      delete fs_; fs_ = NULL;
+    }
+    // delete provider after destructing the streams
+    delete provider_;
+  }
+  // get next line
   virtual bool NextLine(std::string *out_data) {
     if (file_ptr_ >= file_ptr_end_ &&
         offset_curr_ >= offset_end_) return false;
@ -29,15 +60,15 @@ class LineSplitBase : public InputSplit {
       if (reader_.AtEnd()) {
         if (out_data->length() != 0) return true;
         file_ptr_ += 1;
-        if (offset_curr_ >= offset_end_) return false;
         if (offset_curr_ != file_offset_[file_ptr_]) {
-          utils::Error("warning:std::FILE size not calculated correctly\n");
+          utils::Error("warning: FILE size not calculated correctly\n");
           offset_curr_ = file_offset_[file_ptr_];
         }
+        if (offset_curr_ >= offset_end_) return false;
         utils::Assert(file_ptr_ + 1 < file_offset_.size(),
                       "boundary check");
         delete fs_;
-        fs_ = this->GetFile(file_ptr_);
+        fs_ = provider_->Open(file_ptr_);
         reader_.set_stream(fs_);
       } else {
         ++offset_curr_;
@ -51,15 +82,27 @@ class LineSplitBase : public InputSplit {
         }
       }
     }
-
- protected:
-  // constructor
-  LineSplitBase(void)
-      : fs_(NULL), reader_(kBufferSize) {
+  /*!
+   * \brief split names given
+   * \param out_fname output std::FILE names
+   * \param uri_ the iput uri std::FILE
+   * \param dlm deliminetr
+   */
+  inline static void SplitNames(std::vector<std::string> *out_fname,
+                                const char *uri_,
+                                const char *dlm) {
+    std::string uri = uri_;
+    char *p = std::strtok(BeginPtr(uri), dlm);
+    while (p != NULL) {
+      out_fname->push_back(std::string(p));
+      p = std::strtok(NULL, dlm);
+    }
   }
+
+ private:
   /*!
    * \brief initialize the line spliter,
-   * \param file_size, size of each std::FILEs
+   * \param file_size, size of each files
    * \param rank the current rank of the data
    * \param nsplit number of split we will divide the data into
    */
@ -82,7 +125,7 @@ class LineSplitBase : public InputSplit {
     file_ptr_end_ = std::upper_bound(file_offset_.begin(),
                                      file_offset_.end(),
                                      offset_end_) - file_offset_.begin() - 1;
-    fs_ = GetFile(file_ptr_);
+    fs_ = provider_->Open(file_ptr_);
     reader_.set_stream(fs_);
     // try to set the starting position correctly
     if (file_offset_[file_ptr_] != offset_begin_) {
@ -94,33 +137,15 @@ class LineSplitBase : public InputSplit {
       }
     }
   }
-  /*!
-   * \brief get the seek stream of given file_index
-   * \return the corresponding seek stream at head of std::FILE
-   */
-  virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
-  /*!
-   * \brief split names given
-   * \param out_fname output std::FILE names
-   * \param uri_ the iput uri std::FILE
-   * \param dlm deliminetr
-   */
-  inline static void SplitNames(std::vector<std::string> *out_fname,
-                                const char *uri_,
-                                const char *dlm) {
-    std::string uri = uri_;
-    char *p = std::strtok(BeginPtr(uri), dlm);
-    while (p != NULL) {
-      out_fname->push_back(std::string(p));
-      p = std::strtok(NULL, dlm);
-    }
-  }
  private:
+  /*! \brief FileProvider */
+  IFileProvider *provider_;
   /*! \brief current input stream */
   utils::ISeekStream *fs_;
-  /*! \brief std::FILE pointer of which std::FILE to read on */
+  /*! \brief file pointer of which file to read on */
   size_t file_ptr_;
-  /*! \brief std::FILE pointer where the end of std::FILE lies */
+  /*! \brief file pointer where the end of file lies */
   size_t file_ptr_end_;
   /*! \brief get the current offset */
   size_t offset_curr_;
@ -128,7 +153,7 @@ class LineSplitBase : public InputSplit {
   size_t offset_begin_;
   /*! \brief end of the offset */
   size_t offset_end_;
-  /*! \brief byte-offset of each std::FILE */
+  /*! \brief byte-offset of each file */
   std::vector<size_t> file_offset_;
   /*! \brief buffer reader */
   StreamBufferReader reader_;

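For reference, the Init logic that LineSplitter keeps from the old class (only file opening now goes through provider_) gives worker rank of nsplit a contiguous byte range over the concatenated files and then skips forward to the next full line. A simplified restatement of that arithmetic; the function below is illustrative, not the source:

#include <algorithm>
#include <cstddef>
#include <vector>

// Byte range [begin, end) that worker `rank` out of `nsplit` is responsible for.
inline void ComputeByteRange(const std::vector<size_t> &file_size,
                             unsigned rank, unsigned nsplit,
                             size_t *begin, size_t *end) {
  size_t total = 0;
  for (size_t i = 0; i < file_size.size(); ++i) total += file_size[i];
  size_t nstep = (total + nsplit - 1) / nsplit;  // ceil(total / nsplit)
  *begin = std::min(nstep * rank, total);
  *end = std::min(nstep * (rank + 1), total);
}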

@ -1,4 +1,10 @@
-# specify tensor path
+ifneq ("$(wildcard ../config.mk)","")
+config = ../config.mk
+else
+config = ../make/config.mk
+endif
+include $(config)
 BIN = linear.rabit
 MOCKBIN= linear.mock
 MPIBIN =
@ -6,10 +12,10 @@ MPIBIN =
 OBJ = linear.o
 # common build script for programs
-include ../make/config.mk
 include ../make/common.mk
 CFLAGS+=-fopenmp
 linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
 # dependenies here
 linear.rabit: linear.o lib
 linear.mock: linear.o lib


@ -206,21 +206,22 @@ int main(int argc, char *argv[]) {
     rabit::Finalize();
     return 0;
   }
-  rabit::linear::LinearObjFunction linear;
+  rabit::linear::LinearObjFunction *linear = new rabit::linear::LinearObjFunction();
   if (!strcmp(argv[1], "stdin")) {
-    linear.LoadData(argv[1]);
+    linear->LoadData(argv[1]);
     rabit::Init(argc, argv);
   } else {
     rabit::Init(argc, argv);
-    linear.LoadData(argv[1]);
+    linear->LoadData(argv[1]);
   }
   for (int i = 2; i < argc; ++i) {
     char name[256], val[256];
     if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
-      linear.SetParam(name, val);
+      linear->SetParam(name, val);
     }
   }
-  linear.Run();
+  linear->Run();
+  delete linear;
   rabit::Finalize();
   return 0;
 }


@ -26,10 +26,11 @@ struct LinearModel {
     int reserved[16];
     // constructor
     ModelParam(void) {
+      memset(this, 0, sizeof(ModelParam));
       base_score = 0.5f;
       num_feature = 0;
       loss_type = 1;
-      std::memset(reserved, 0, sizeof(reserved));
+      num_feature = 0;
     }
     // initialize base score
     inline void InitBaseScore(void) {
@ -119,7 +120,7 @@ struct LinearModel {
     }
     fi.Read(weight, sizeof(float) * (param.num_feature + 1));
   }
-  inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
+  inline void Save(rabit::IStream &fo, const float *wptr = NULL) {
     fo.Write(&param, sizeof(param));
     if (wptr == NULL) wptr = weight;
     fo.Write(wptr, sizeof(float) * (param.num_feature + 1));


@ -6,12 +6,13 @@ then
fi fi
# put the local training file to HDFS # put the local training file to HDFS
hadoop fs -rm -r -f $2/data
hadoop fs -rm -r -f $2/mushroom.linear.model hadoop fs -rm -r -f $2/mushroom.linear.model
hadoop fs -mkdir $2/data hadoop fs -mkdir $2/data
hadoop fs -put ../data/agaricus.txt.train $2/data
# submit to hadoop # submit to hadoop
../../tracker/rabit_yarn.py -n $1 --vcores 1 linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}" ../../tracker/rabit_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
# get the final model file # get the final model file
hadoop fs -get $2/mushroom.linear.model ./linear.model hadoop fs -get $2/mushroom.linear.model ./linear.model


@ -6,7 +6,7 @@
 #
 # - copy this file to the root of rabit-learn folder
 # - modify the configuration you want
-# - type make or make -j n for parallel build
+# - type make or make -j n on each of the folder
 #----------------------------------------------------
 # choice of compiler


@ -145,8 +145,9 @@ class LBFGSSolver {
     if (silent == 0 && rabit::GetRank() == 0) {
       rabit::TrackerPrintf
-          ("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu\n",
-           gstate.num_dim, gstate.init_objval, gstate.size_memory);
+          ("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu, RAM-approx=%lu\n",
+           gstate.num_dim, gstate.init_objval, gstate.size_memory,
+           gstate.MemCost() + hist.MemCost());
     }
   }
 }
@ -176,7 +177,7 @@ class LBFGSSolver {
     // swap new weight
     std::swap(g.weight, g.grad);
     // check stop condition
-    if (gstate.num_iteration > min_lbfgs_iter) {
+    if (gstate.num_iteration > static_cast<size_t>(min_lbfgs_iter)) {
       if (g.old_objval - g.new_objval < lbfgs_stop_tol * g.init_objval) {
         return true;
       }
@ -195,7 +196,7 @@ class LBFGSSolver {
   /*! \brief run optimization */
   virtual void Run(void) {
     this->Init();
-    while (gstate.num_iteration < max_lbfgs_iter) {
+    while (gstate.num_iteration < static_cast<size_t>(max_lbfgs_iter)) {
       if (this->UpdateOneIter()) break;
     }
     if (silent == 0 && rabit::GetRank() == 0) {
@ -225,7 +226,7 @@ class LBFGSSolver {
     const size_t num_dim = gstate.num_dim;
     const DType *gsub = grad + range_begin_;
     const size_t nsub = range_end_ - range_begin_;
-    double vdot;
+    double vdot = 0.0;
     if (n != 0) {
       // hist[m + n - 1] stores old gradient
       Minus(hist[m + n - 1], gsub, hist[m + n - 1], nsub);
@ -241,15 +242,19 @@ class LBFGSSolver {
idxset.push_back(std::make_pair(m + j, 2 * m)); idxset.push_back(std::make_pair(m + j, 2 * m));
idxset.push_back(std::make_pair(m + j, m + n - 1)); idxset.push_back(std::make_pair(m + j, m + n - 1));
} }
// calculate dot products // calculate dot products
std::vector<double> tmp(idxset.size()); std::vector<double> tmp(idxset.size());
for (size_t i = 0; i < tmp.size(); ++i) { for (size_t i = 0; i < tmp.size(); ++i) {
tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second); tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second);
} }
rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size()); rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size());
for (size_t i = 0; i < tmp.size(); ++i) { for (size_t i = 0; i < tmp.size(); ++i) {
gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i]; gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
} }
// BFGS steps, use vector-free update // BFGS steps, use vector-free update
// parameterize vector using basis in hist // parameterize vector using basis in hist
std::vector<double> alpha(n); std::vector<double> alpha(n);
@ -263,7 +268,7 @@ class LBFGSSolver {
} }
alpha[j] = vsum / gstate.DotBuf(j, m + j); alpha[j] = vsum / gstate.DotBuf(j, m + j);
delta[m + j] = delta[m + j] - alpha[j]; delta[m + j] = delta[m + j] - alpha[j];
} }
// scale // scale
double scale = gstate.DotBuf(n - 1, m + n - 1) / double scale = gstate.DotBuf(n - 1, m + n - 1) /
gstate.DotBuf(m + n - 1, m + n - 1); gstate.DotBuf(m + n - 1, m + n - 1);
@ -279,6 +284,7 @@ class LBFGSSolver {
double beta = vsum / gstate.DotBuf(j, m + j); double beta = vsum / gstate.DotBuf(j, m + j);
delta[j] = delta[j] + (alpha[j] - beta); delta[j] = delta[j] + (alpha[j] - beta);
} }
// set all to zero // set all to zero
std::fill(dir, dir + num_dim, 0.0f); std::fill(dir, dir + num_dim, 0.0f);
DType *dirsub = dir + range_begin_; DType *dirsub = dir + range_begin_;
@ -291,10 +297,11 @@ class LBFGSSolver {
} }
FixDirL1Sign(dirsub, hist[2 * m], nsub); FixDirL1Sign(dirsub, hist[2 * m], nsub);
vdot = -Dot(dirsub, hist[2 * m], nsub); vdot = -Dot(dirsub, hist[2 * m], nsub);
// allreduce to get full direction // allreduce to get full direction
rabit::Allreduce<rabit::op::Sum>(dir, num_dim); rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
rabit::Allreduce<rabit::op::Sum>(&vdot, 1); rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
} else { } else {
SetL1Dir(dir, grad, weight, num_dim); SetL1Dir(dir, grad, weight, num_dim);
vdot = -Dot(dir, dir, num_dim); vdot = -Dot(dir, dir, num_dim);
} }
@ -482,6 +489,7 @@ class LBFGSSolver {
num_iteration = 0; num_iteration = 0;
num_dim = 0; num_dim = 0;
old_objval = 0.0; old_objval = 0.0;
offset_ = 0;
} }
~GlobalState(void) { ~GlobalState(void) {
if (grad != NULL) { if (grad != NULL) {
@ -496,6 +504,10 @@ class LBFGSSolver {
data.resize(n * n, 0.0); data.resize(n * n, 0.0);
this->AllocSpace(); this->AllocSpace();
} }
// memory cost
inline size_t MemCost(void) const {
return sizeof(DType) * 3 * num_dim;
}
inline double &DotBuf(size_t i, size_t j) { inline double &DotBuf(size_t i, size_t j) {
if (i > j) std::swap(i, j); if (i > j) std::swap(i, j);
return data[MapIndex(i, offset_, size_memory) * (size_memory * 2 + 1) + return data[MapIndex(i, offset_, size_memory) * (size_memory * 2 + 1) +
@ -565,6 +577,10 @@ class LBFGSSolver {
size_t n = size_memory * 2 + 1; size_t n = size_memory * 2 + 1;
dptr_ = new DType[n * stride_]; dptr_ = new DType[n * stride_];
} }
// memory cost
inline size_t MemCost(void) const {
return sizeof(DType) * (size_memory_ * 2 + 1) * stride_;
}
// fetch element from rolling array // fetch element from rolling array
inline const DType *operator[](size_t i) const { inline const DType *operator[](size_t i) const {
return dptr_ + MapIndex(i, offset_, size_memory_) * stride_; return dptr_ + MapIndex(i, offset_, size_memory_) * stride_;


@ -77,11 +77,15 @@ struct SparseMat {
feat_dim += 1; feat_dim += 1;
utils::Check(feat_dim < std::numeric_limits<index_t>::max(), utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
"feature dimension exceed limit of index_t"\ "feature dimension exceed limit of index_t"\
"consider change the index_t to unsigned long"); "consider change the index_t to unsigned long");
} }
inline size_t NumRow(void) const { inline size_t NumRow(void) const {
return row_ptr.size() - 1; return row_ptr.size() - 1;
} }
// memory cost
inline size_t MemCost(void) const {
return data.size() * sizeof(Entry);
}
// maximum feature dimension // maximum feature dimension
size_t feat_dim; size_t feat_dim;
std::vector<size_t> row_ptr; std::vector<size_t> row_ptr;

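The MemCost() helpers added in the solver and in SparseMat feed the new RAM-approx figure printed at startup: roughly three full-dimension vectors in GlobalState plus a (2 * size_memory + 1)-row history over this worker's range, on top of the data matrix itself. A back-of-the-envelope restatement; the free function below is illustrative:

#include <cstddef>

// Approximate L-BFGS solver memory, mirroring the MemCost() additions above;
// dtype_bytes is sizeof(DType), range_len approximates the per-worker stride.
inline size_t ApproxSolverRam(size_t num_dim, size_t size_memory,
                              size_t range_len, size_t dtype_bytes) {
  size_t global_cost = dtype_bytes * 3 * num_dim;                      // GlobalState::MemCost
  size_t hist_cost = dtype_bytes * (size_memory * 2 + 1) * range_len;  // history MemCost
  return global_cost + hist_cost;
}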

@ -26,6 +26,9 @@ AllreduceBase::AllreduceBase(void) {
world_size = -1; world_size = -1;
hadoop_mode = 0; hadoop_mode = 0;
version_number = 0; version_number = 0;
// 32 K items
reduce_ring_mincount = 32 << 10;
// tracker URL
task_id = "NULL"; task_id = "NULL";
err_link = NULL; err_link = NULL;
this->SetParam("rabit_reduce_buffer", "256MB"); this->SetParam("rabit_reduce_buffer", "256MB");
@ -33,7 +36,8 @@ AllreduceBase::AllreduceBase(void) {
env_vars.push_back("rabit_task_id"); env_vars.push_back("rabit_task_id");
env_vars.push_back("rabit_num_trial"); env_vars.push_back("rabit_num_trial");
env_vars.push_back("rabit_reduce_buffer"); env_vars.push_back("rabit_reduce_buffer");
env_vars.push_back("rabit_tracker_uri"); env_vars.push_back("rabit_reduce_ring_mincount");
env_vars.push_back("rabit_tracker_uri");
env_vars.push_back("rabit_tracker_port"); env_vars.push_back("rabit_tracker_port");
} }
@ -116,6 +120,27 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.SendStr(msg); tracker.SendStr(msg);
tracker.Close(); tracker.Close();
} }
// util to parse data with unit suffix
inline size_t ParseUnit(const char *name, const char *val) {
char unit;
uint64_t amount;
int n = sscanf(val, "%lu%c", &amount, &unit);
if (n == 2) {
switch (unit) {
case 'B': return amount;
case 'K': return amount << 10UL;
case 'M': return amount << 20UL;
case 'G': return amount << 30UL;
default: utils::Error("invalid format for %s", name); return 0;
}
} else if (n == 1) {
return amount;
} else {
utils::Error("invalid format for %s," \
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
return 0;
}
}
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
@ -127,21 +152,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_task_id")) task_id = val; if (!strcmp(name, "rabit_task_id")) task_id = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
reduce_ring_mincount = ParseUnit(name, val);
}
if (!strcmp(name, "rabit_reduce_buffer")) { if (!strcmp(name, "rabit_reduce_buffer")) {
char unit; reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
uint64_t amount;
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
switch (unit) {
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
case 'K': reduce_buffer_size = amount << 7UL; break;
case 'M': reduce_buffer_size = amount << 17UL; break;
case 'G': reduce_buffer_size = amount << 27UL; break;
default: utils::Error("invalid format for reduce buffer");
}
} else {
utils::Error("invalid format for reduce_buffer,"\
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
}
} }
} }
/*! /*!
@ -341,6 +356,28 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
ReduceFunction reducer) { ReduceFunction reducer) {
if (count > reduce_ring_mincount) {
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
} else {
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
}
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links; RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess; if (links.size() == 0 || count == 0) return kSuccess;
// total size of message // total size of message
@ -411,7 +448,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
       // read data from childs
       for (int i = 0; i < nlink; ++i) {
         if (i != parent_index && selecter.CheckRead(links[i].sock)) {
-          ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
+          ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
           if (ret != kSuccess) {
             return ReportError(&links[i], ret);
           }
@ -599,5 +636,217 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} }
return kSuccess; return kSuccess;
} }
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
const size_t stop_read = total_size + slice_begin;
const size_t stop_write = total_size + slice_begin - size_prev_slice;
size_t write_ptr = slice_begin;
size_t read_ptr = slice_end;
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
if (len != -1) {
read_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) return ReportError(&next, ret);
}
}
if (write_ptr < read_ptr && write_ptr != stop_write) {
size_t size = std::min(read_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
* and will return the cause of failure
*
* Ring-based algorithm
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
AllreduceBase::ReturnType
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// total size of message
const size_t total_size = type_nbytes * count;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t r = static_cast<size_t>(next.rank);
size_t write_ptr = std::min(r * step, count) * type_nbytes;
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
size_t reduce_ptr = read_ptr;
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// position to stop reading
const size_t stop_read = total_size + write_ptr;
// position to stop writing
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
if (stop_write > stop_read) {
stop_write -= total_size;
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
}
// use ring buffer in next position
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
// set size_read to read pointer for ring buffer to work properly
next.size_read = read_ptr;
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
}
// sync the rate
read_ptr = next.size_read;
utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
const size_t buffer_size = next.buffer_size;
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
while (reduce_ptr < max_reduce) {
size_t bstart = reduce_ptr % buffer_size;
size_t nread = std::min(buffer_size - bstart,
max_reduce - reduce_ptr);
size_t rstart = reduce_ptr % total_size;
nread = std::min(nread, total_size - rstart);
reducer(next.buffer_head + bstart,
sendrecvbuf + rstart,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
reduce_ptr += nread;
}
}
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
if (ret != kSuccess) return ret;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t begin = std::min(rank * step, count) * type_nbytes;
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
// previous rank
int prank = ring_prev->rank;
// get rank of previous
return TryAllgatherRing
(sendrecvbuf_, type_nbytes * count,
begin, end,
(std::min((prank + 1) * step, count) -
std::min(prank * step, count)) * type_nbytes);
}
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

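TryAllreduce now dispatches on message size: counts above reduce_ring_mincount (initialized to 32K elements earlier in this file) use the new ring algorithm, a reduce-scatter followed by a ring allgather, while smaller messages keep the tree algorithm. The slice bookkeeping both ring phases share is restated below with the same formulas; the helper itself is illustrative:

#include <algorithm>
#include <cstddef>

// After reduce-scatter, rank k holds the fully reduced elements
// [k * step, min((k + 1) * step, count)) with step = ceil(count / world_size);
// the allgather phase then circulates those slices around the ring.
inline void SliceOfRank(int rank, int world_size,
                        size_t count, size_t type_nbytes,
                        size_t *begin_bytes, size_t *end_bytes) {
  size_t n = static_cast<size_t>(world_size);
  size_t step = (count + n - 1) / n;
  *begin_bytes = std::min(static_cast<size_t>(rank) * step, count) * type_nbytes;
  *end_bytes = std::min(static_cast<size_t>(rank + 1) * step, count) * type_nbytes;
}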

@ -278,15 +278,19 @@ class AllreduceBase : public IEngine {
    * \brief read data into ring-buffer, with care not to existing useful override data
    *  position after protect_start
    * \param protect_start all data start from protect_start is still needed in buffer
    *  read shall not override this
+   * \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
    * \return the type of reading
    */
-  inline ReturnType ReadToRingBuffer(size_t protect_start) {
+  inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
     utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
+    utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
     size_t ngap = size_read - protect_start;
     utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
     size_t offset = size_read % buffer_size;
-    size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
+    size_t nmax = max_size_read - size_read;
+    nmax = std::min(nmax, buffer_size - ngap);
+    nmax = std::min(nmax, buffer_size - offset);
     if (nmax == 0) return kSuccess;
     ssize_t len = sock.Recv(buffer_head + offset, nmax);
     // length equals 0, remote disconnected
@ -380,13 +384,79 @@ class AllreduceBase : public IEngine {
                          ReduceFunction reducer);
   /*!
    * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
-   * \param sendrecvbuf_ buffer for both sending and recving data
+   * \param sendrecvbuf_ buffer for both sending and receiving data
    * \param size the size of the data to be broadcasted
    * \param root the root worker id to broadcast the data
    * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
    * \sa ReturnType
    */
   ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin, size_t slice_end,
size_t size_prev_slice);
/*!
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
*
* after the function, node k get k-th segment of the reduction result
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
* where step = ceil(count / world_size)
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm, reduce-scatter + allgather
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*! /*!
* \brief function used to report error when a link goes wrong * \brief function used to report error when a link goes wrong
* \param link the pointer to the link who causes the error * \param link the pointer to the link who causes the error
@ -432,6 +502,10 @@ class AllreduceBase : public IEngine {
int slave_port, nport_trial; int slave_port, nport_trial;
// reduce buffer size // reduce buffer size
size_t reduce_buffer_size; size_t reduce_buffer_size;
// reduction method
int reduce_method;
// mininum count of cells to use ring based method
size_t reduce_ring_mincount;
// current rank // current rank
int rank; int rank;
// world size // world size

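The new reduce_ring_mincount member is configurable through the usual parameter path as rabit_reduce_ring_mincount, and ParseUnit accepts the same B/K/M/G unit suffixes as rabit_reduce_buffer, so the tree-vs-ring switch-over point can be tuned per job. A hedged sketch of passing it on the command line of any rabit program; the program name and value are illustrative:

#include <rabit.h>  // assumed public rabit header

int main(int argc, char *argv[]) {
  // e.g.  ./my_program rabit_reduce_ring_mincount=1M
  // asks for the ring algorithm on allreduces larger than ~1M elements and
  // the tree algorithm below that; Init forwards key=value arguments to
  // the engine's SetParam.
  rabit::Init(argc, argv);
  // ... Allreduce / Broadcast / CheckPoint calls ...
  rabit::Finalize();
  return 0;
}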

@ -81,18 +81,18 @@ class AllreduceMock : public AllreduceRobust {
       ComboSerializer com(global_model, local_model);
       AllreduceRobust::CheckPoint(&dum, &com);
     }
-    tsum_allreduce = 0.0;
     time_checkpoint = utils::GetTime();
     double tcost = utils::GetTime() - tstart;
     if (report_stats != 0 && rank == 0) {
       std::stringstream ss;
       ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
-         << "local_size=" << local_chkpt[local_chkpt_version].length()
-         << "check_tcost="<< tcost <<" sec,"
-         << "allreduce_tcost=" << tsum_allreduce << " sec,"
-         << "between_chpt=" << tbet_chkpt << "sec\n";
+         << ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
+         << ",check_tcost="<< tcost <<" sec"
+         << ",allreduce_tcost=" << tsum_allreduce << " sec"
+         << ",between_chpt=" << tbet_chkpt << "sec\n";
       this->TrackerPrint(ss.str());
     }
+    tsum_allreduce = 0.0;
   }
virtual void LazyCheckPoint(const ISerializable *global_model) { virtual void LazyCheckPoint(const ISerializable *global_model) {


@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
       if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
     }
     utils::Assert(min_write <= links[pid].size_read, "boundary check");
-    ReturnType ret = links[pid].ReadToRingBuffer(min_write);
+    ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
     if (ret != kSuccess) {
       return ReportError(&links[pid], ret);
     }


@ -287,7 +287,6 @@ class AllreduceRobust : public AllreduceBase {
if (seqno_.size() == 0) return -1; if (seqno_.size() == 0) return -1;
return seqno_.back(); return seqno_.back();
} }
private: private:
// sequence number of each // sequence number of each
std::vector<int> seqno_; std::vector<int> seqno_;


@ -1,7 +1,7 @@
 export CC = gcc
 export CXX = g++
 export MPICXX = mpicxx
-export LDFLAGS= -pthread -lm -lrt -L../lib
+export LDFLAGS= -L../lib -pthread -lm -lrt
 export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11
 # specify tensor path
@ -29,7 +29,7 @@ local_recover: local_recover.o $(RABIT_OBJ)
 lazy_recover: lazy_recover.o $(RABIT_OBJ)
 $(BIN) :
-	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
+	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
 $(OBJ) :
 	$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )


@ -23,4 +23,7 @@ lazy_recover_10_10k_die_hard:
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 ../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
lazy_recover_10_10k_die_same: lazy_recover_10_10k_die_same:
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 ../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
ringallreduce_10_10k:
../tracker/rabit_demo.py -v 1 -n 10 model_recover 100 rabit_reduce_ring_mincount=10


@ -9,4 +9,4 @@ the example guidelines are in the script themselfs
* Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py) * Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
- It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py - It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py
- However, it is higly recommended to use rabit_yarn.py because this will allocate resources more precisely and fits machine learning scenarios - However, it is higly recommended to use rabit_yarn.py because this will allocate resources more precisely and fits machine learning scenarios
* Sun Grid engine: [rabit_sge.py](rabit_sge.py)


@ -1,7 +1,6 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
-This is the demo submission script of rabit, it is created to
-submit rabit jobs using hadoop streaming
+This is the demo submission script of rabit for submitting jobs in local machine
 """
 import argparse
 import sys
@ -43,7 +42,7 @@ def exec_cmd(cmd, taskid, worker_env):
     if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
         cmd[0] = './' + cmd[0]
     cmd = ' '.join(cmd)
-    env = {}
+    env = os.environ.copy()
     for k, v in worker_env.items():
         env[k] = str(v)
     env['rabit_task_id'] = str(taskid)


@ -1,4 +1,4 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
 Deprecated


@ -1,7 +1,6 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
-This is the demo submission script of rabit, it is created to
-submit rabit jobs using hadoop streaming
+Submission script to submit rabit jobs using MPI
 """
 import argparse
 import sys


@ -0,0 +1,69 @@
#!/usr/bin/env python
"""
Submit rabit jobs to Sun Grid Engine
"""
import argparse
import sys
import os
import subprocess
import rabit_tracker as tracker
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job using MPI')
parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker proccess to be launched')
parser.add_argument('-q', '--queue', default='default', type=str,
help = 'the queue we want to submit the job to')
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('--vcores', default = 1, type=int,
help = 'number of vcpores to request in each mapper, set it if each rabit job is multi-threaded')
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
parser.add_argument('--logdir', default='auto', help = 'customize the directory to place the logs')
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
help = 'print more messages into the console')
parser.add_argument('command', nargs='+',
help = 'command for rabit program')
args = parser.parse_args()
if args.jobname == 'auto':
args.jobname = ('rabit%d.' % args.nworker) + args.command[0].split('/')[-1];
if args.logdir == 'auto':
args.logdir = args.jobname + '.log'
if os.path.exists(args.logdir):
if not os.path.isdir(args.logdir):
raise RuntimeError('specified logdir %s is a file instead of directory' % args.logdir)
else:
os.mkdir(args.logdir)
runscript = '%s/runrabit.sh' % args.logdir
fo = open(runscript, 'w')
fo.write('\"$@\"\n')
fo.close()
#
# submission script using MPI
#
def sge_submit(nslave, worker_args, worker_envs):
"""
customized submit script, that submit nslave jobs, each must contain args as parameter
note this can be a lambda function containing additional parameters in input
Parameters
nslave number of slave process to start up
args arguments to launch each job
this usually includes the parameters of master_uri and parameters passed into submit
"""
env_arg = ','.join(['%s=\"%s\"' % (k, str(v)) for k, v in worker_envs.items()])
cmd = 'qsub -cwd -t 1-%d -S /bin/bash' % nslave
if args.queue != 'default':
cmd += '-q %s' % args.queue
cmd += ' -N %s ' % args.jobname
cmd += ' -e %s -o %s' % (args.logdir, args.logdir)
cmd += ' -pe orte %d' % (args.vcores)
cmd += ' -v %s,PATH=${PATH}:.' % env_arg
cmd += ' %s %s' % (runscript, ' '.join(args.command + worker_args))
print cmd
subprocess.check_call(cmd, shell = True)
print 'Waiting for the jobs to get up...'
# call submit, with nslave, the commands to run each job and submit function
tracker.submit(args.nworker, [], fun_submit = sge_submit, verbose = args.verbose)


@ -13,6 +13,7 @@ import socket
import struct import struct
import subprocess import subprocess
import random import random
import time
from threading import Thread from threading import Thread
""" """
@ -188,6 +189,7 @@ class Tracker:
vlst.reverse() vlst.reverse()
rlst += vlst rlst += vlst
return rlst return rlst
def get_ring(self, tree_map, parent_map): def get_ring(self, tree_map, parent_map):
""" """
get a ring connection used to recover local data get a ring connection used to recover local data
@ -202,14 +204,44 @@ class Tracker:
rnext = (r + 1) % nslave rnext = (r + 1) % nslave
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map return ring_map
def get_link_map(self, nslave):
"""
get the link map, this is a bit hacky, call for better algorithm
to place similar nodes together
"""
tree_map, parent_map = self.get_tree(nslave)
ring_map = self.get_ring(tree_map, parent_map)
rmap = {0 : 0}
k = 0
for i in range(nslave - 1):
k = ring_map[k][1]
rmap[k] = i + 1
ring_map_ = {}
tree_map_ = {}
parent_map_ ={}
for k, v in ring_map.items():
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
for k, v in tree_map.items():
tree_map_[rmap[k]] = [rmap[x] for x in v]
for k, v in parent_map.items():
if k != 0:
parent_map_[rmap[k]] = rmap[v]
else:
parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_
def handle_print(self,slave, msg): def handle_print(self,slave, msg):
sys.stdout.write(msg) sys.stdout.write(msg)
def log_print(self, msg, level): def log_print(self, msg, level):
if level == 1: if level == 1:
if self.verbose: if self.verbose:
sys.stderr.write(msg + '\n') sys.stderr.write(msg + '\n')
else: else:
sys.stderr.write(msg + '\n') sys.stderr.write(msg + '\n')
def accept_slaves(self, nslave): def accept_slaves(self, nslave):
# set of nodes that finishs the job # set of nodes that finishs the job
shutdown = {} shutdown = {}
@ -241,31 +273,40 @@ class Tracker:
            assert s.cmd == 'start'
            if s.world_size > 0:
                nslave = s.world_size
-               tree_map, parent_map = self.get_tree(nslave)
-               ring_map = self.get_ring(tree_map, parent_map)
+               tree_map, parent_map, ring_map = self.get_link_map(nslave)
                # set of nodes that is pending for getting up
                todo_nodes = range(nslave)
-               random.shuffle(todo_nodes)
            else:
                assert s.world_size == -1 or s.world_size == nslave
            if s.cmd == 'recover':
                assert s.rank >= 0
            rank = s.decide_rank(job_map)
+           # batch assignment of ranks
            if rank == -1:
                assert len(todo_nodes) != 0
-               rank = todo_nodes.pop(0)
-               if s.jobid != 'NULL':
-                   job_map[s.jobid] = rank
+               pending.append(s)
+               if len(pending) == len(todo_nodes):
+                   pending.sort(key = lambda x : x.host)
+                   for s in pending:
+                       rank = todo_nodes.pop(0)
+                       if s.jobid != 'NULL':
+                           job_map[s.jobid] = rank
+                       s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
+                       if s.wait_accept > 0:
+                           wait_conn[rank] = s
+                       self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
                    if len(todo_nodes) == 0:
                        self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
+                       self.start_time = time.time()
-               s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-               if s.cmd != 'start':
-                   self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
            else:
-               self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
-               if s.wait_accept > 0:
-                   wait_conn[rank] = s
+               s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
+               self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
+               if s.wait_accept > 0:
+                   wait_conn[rank] = s
        self.log_print('@tracker All nodes finishes job', 2)
+       self.end_time = time.time()
+       self.log_print('@tracker %s secs between node start and job finish' % str(self.end_time - self.start_time), 2)
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
    master = Tracker(verbose = verbose, hostIP = hostIP)

View File

@@ -1,4 +1,4 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
 This is a script to submit rabit job via Yarn
 rabit will run as a Yarn application
@@ -13,6 +13,7 @@ import rabit_tracker as tracker
 
 WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
 YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
+YARN_BOOT_PY = os.path.dirname(__file__) + '/../yarn/run_hdfs_prog.py'
 
 if not os.path.exists(YARN_JAR_PATH):
     warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
@@ -21,7 +22,7 @@ if not os.path.exists(YARN_JAR_PATH):
     subprocess.check_call(cmd, shell = True, env = os.environ)
     assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
 
-hadoop_binary = 'hadoop'
+hadoop_binary = None
 # code
 hadoop_home = os.getenv('HADOOP_HOME')
@@ -38,6 +39,8 @@ parser.add_argument('-hip', '--host_ip', default='auto', type=str,
                     help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
 parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
                     help = 'print more messages into the console')
+parser.add_argument('-q', '--queue', default='default', type=str,
+                    help = 'the queue we want to submit the job to')
 parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
                     help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
 parser.add_argument('-f', '--files', default = [], action='append',
@@ -56,6 +59,11 @@ parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
                     help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
                     'if you are running multi-threading rabit,'\
                     'so that each node can occupy all the mapper slots in a machine for maximum performance')
+parser.add_argument('--libhdfs-opts', default='-Xmx128m', type=str,
+                    help = 'setting to be passed to libhdfs')
+parser.add_argument('--name-node', default='default', type=str,
+                    help = 'the namenode address of hdfs, libhdfs should connect to, normally leave it as default')
 parser.add_argument('command', nargs='+',
                     help = 'command for rabit program')
 args = parser.parse_args()
@@ -87,7 +95,7 @@ if hadoop_version < 2:
     print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]
 
 def submit_yarn(nworker, worker_args, worker_env):
-    fset = set([YARN_JAR_PATH])
+    fset = set([YARN_JAR_PATH, YARN_BOOT_PY])
     if args.auto_file_cache != 0:
         for i in range(len(args.command)):
             f = args.command[i]
@@ -96,7 +104,7 @@ def submit_yarn(nworker, worker_args, worker_env):
                 if i == 0:
                     args.command[i] = './' + args.command[i].split('/')[-1]
                 else:
-                    args.command[i] = args.command[i].split('/')[-1]
+                    args.command[i] = './' + args.command[i].split('/')[-1]
     if args.command[0].endswith('.py'):
         flst = [WRAPPER_PATH + '/rabit.py',
                 WRAPPER_PATH + '/librabit_wrapper.so',
@@ -112,6 +120,8 @@ def submit_yarn(nworker, worker_args, worker_env):
     env['rabit_cpu_vcores'] = str(args.vcores)
     env['rabit_memory_mb'] = str(args.memory_mb)
     env['rabit_world_size'] = str(args.nworker)
+    env['rabit_hdfs_opts'] = str(args.libhdfs_opts)
+    env['rabit_hdfs_namenode'] = str(args.name_node)
     if args.files != None:
         for flst in args.files:
@@ -121,7 +131,8 @@ def submit_yarn(nworker, worker_args, worker_env):
         cmd += ' -file %s' % f
     cmd += ' -jobname %s ' % args.jobname
     cmd += ' -tempdir %s ' % args.tempdir
-    cmd += (' '.join(args.command + worker_args))
+    cmd += ' -queue %s ' % args.queue
+    cmd += (' '.join(['./run_hdfs_prog.py'] + args.command + worker_args))
     if args.verbose != 0:
         print cmd
     subprocess.check_call(cmd, shell = True, env = env)
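
With the new options wired through, a submission through this script that targets a specific YARN queue and tunes libhdfs memory could look roughly like the following; the worker command, queue name and values are illustrative, and the worker-count and other required options are omitted:

    ./rabit_yarn.py -q production -mem 2048 --libhdfs-opts=-Xmx256m ./my_worker param1=value1
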

View File

@@ -1 +0,0 @@
-foler used to hold generated class files

View File

@@ -1,8 +1,8 @@
 #!/bin/bash
-if [ -z "$HADOOP_PREFIX" ]; then
-    echo "cannot found $HADOOP_PREFIX in the environment variable, please set it properly"
-    exit 1
+if [ ! -d bin ]; then
+    mkdir bin
 fi
-CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
+CPATH=`${HADOOP_HOME}/bin/hadoop classpath`
 javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
 jar cf rabit-yarn.jar -C bin .

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python
"""
this script helps setup classpath env for HDFS, before running program
that links with libhdfs
"""
import glob
import sys
import os
import subprocess

if len(sys.argv) < 2:
    print 'Usage: the command you want to run'

hadoop_home = os.getenv('HADOOP_HOME')
hdfs_home = os.getenv('HADOOP_HDFS_HOME')
java_home = os.getenv('JAVA_HOME')
if hadoop_home is None:
    hadoop_home = os.getenv('HADOOP_PREFIX')
assert hadoop_home is not None, 'need to set HADOOP_HOME'
assert hdfs_home is not None, 'need to set HADOOP_HDFS_HOME'
assert java_home is not None, 'need to set JAVA_HOME'
(classpath, err) = subprocess.Popen('%s/bin/hadoop classpath' % hadoop_home,
                                    stdout=subprocess.PIPE, shell = True,
                                    env = os.environ).communicate()
cpath = []
for f in classpath.split(':'):
    cpath += glob.glob(f)
lpath = []
lpath.append('%s/lib/native' % hdfs_home)
lpath.append('%s/jre/lib/amd64/server' % java_home)
env = os.environ.copy()
env['CLASSPATH'] = '${CLASSPATH}:' + (':'.join(cpath))
# setup hdfs options
if 'rabit_hdfs_opts' in env:
    env['LIBHDFS_OPTS'] = env['rabit_hdfs_opts']
elif 'LIBHDFS_OPTS' not in env:
    env['LIBHDFS_OPTS'] = '--Xmx128m'
env['LD_LIBRARY_PATH'] = '${LD_LIBRARY_PATH}:' + (':'.join(lpath))
ret = subprocess.call(args = sys.argv[1:], env = env)
sys.exit(ret)
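
This boot script is what rabit_yarn.py now prepends to every worker command (the './run_hdfs_prog.py' prefix in the submit hunk above): it expands the hadoop classpath entries with glob, points LD_LIBRARY_PATH at libhdfs and the JVM, and then runs the real command with that environment. Inside a container the launch therefore amounts to something like the following, where the worker name is purely illustrative:

    ./run_hdfs_prog.py ./my_worker arg1 arg2
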

View File

@@ -1,5 +1,6 @@
 package org.apache.hadoop.yarn.rabit;
 
+import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
@@ -7,6 +8,7 @@ import java.util.Map;
 import java.util.Queue;
 import java.util.Collection;
 import java.util.Collections;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
@@ -34,6 +36,7 @@ import org.apache.hadoop.yarn.api.records.NodeReport;
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
 import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
+import org.apache.hadoop.security.UserGroupInformation;
 
 /**
  * application master for allocating resources of rabit client
@@ -61,6 +64,8 @@ public class ApplicationMaster {
     // command to launch
     private String command = "";
+    // username
+    private String userName = "";
     // application tracker hostname
     private String appHostName = "";
     // tracker URL to do
@@ -128,6 +133,8 @@ public class ApplicationMaster {
      */
     private void initArgs(String args[]) throws IOException {
         LOG.info("Invoke initArgs");
+        // get user name
+        userName = UserGroupInformation.getCurrentUser().getShortUserName();
         // cached maps
         Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
         for (int i = 0; i < args.length; ++i) {
@@ -156,7 +163,8 @@ public class ApplicationMaster {
         numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
         numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
         numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
-        maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false, maxNumAttempt);
+        maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false,
+                maxNumAttempt);
     }
 
     /**
@@ -175,7 +183,7 @@ public class ApplicationMaster {
         RegisterApplicationMasterResponse response = this.rmClient
                 .registerApplicationMaster(this.appHostName,
                         this.appTrackerPort, this.appTrackerUrl);
         boolean success = false;
         String diagnostics = "";
         try {
@@ -208,19 +216,20 @@ public class ApplicationMaster {
             assert (killedTasks.size() + finishedTasks.size() == numTasks);
             success = finishedTasks.size() == numTasks;
             LOG.info("Application completed. Stopping running containers");
-            nmClient.stop();
             diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
                     + ", finished=" + this.finishedTasks.size() + ", failed="
                     + this.killedTasks.size() + "\n" + this.abortDiagnosis;
+            nmClient.stop();
             LOG.info(diagnostics);
         } catch (Exception e) {
             diagnostics = e.toString();
         }
         rmClient.unregisterApplicationMaster(
                 success ? FinalApplicationStatus.SUCCEEDED
                         : FinalApplicationStatus.FAILED, diagnostics,
                 appTrackerUrl);
-        if (!success) throw new Exception("Application not successful");
+        if (!success)
+            throw new Exception("Application not successful");
     }
 
     /**
@@ -265,30 +274,63 @@ public class ApplicationMaster {
         task.containerRequest = null;
         ContainerLaunchContext ctx = Records
                 .newRecord(ContainerLaunchContext.class);
-        String cmd =
-                // use this to setup CLASSPATH correctly for libhdfs
-                "CLASSPATH=${CLASSPATH}:`${HADOOP_PREFIX}/bin/hadoop classpath --glob` "
-                + this.command + " 1>"
-                + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
-                + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
-                + "/stderr";
-        LOG.info(cmd);
+        String hadoop = "hadoop";
+        if (System.getenv("HADOOP_HOME") != null) {
+            hadoop = "${HADOOP_HOME}/bin/hadoop";
+        } else if (System.getenv("HADOOP_PREFIX") != null) {
+            hadoop = "${HADOOP_PREFIX}/bin/hadoop";
+        }
+        String cmd =
+                // use this to setup CLASSPATH correctly for libhdfs
+                this.command + " 1>"
+                + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+                + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+                + "/stderr";
         ctx.setCommands(Collections.singletonList(cmd));
         LOG.info(workerResources);
         ctx.setLocalResources(this.workerResources);
         // setup environment variables
         Map<String, String> env = new java.util.HashMap<String, String>();
         // setup class path, this is kind of duplicated, ignoring
         StringBuilder cpath = new StringBuilder("${CLASSPATH}:./*");
         for (String c : conf.getStrings(
                 YarnConfiguration.YARN_APPLICATION_CLASSPATH,
                 YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
-            cpath.append(':');
-            cpath.append(c.trim());
+            String[] arrPath = c.split(":");
+            for (String ps : arrPath) {
+                if (ps.endsWith("*.jar") || ps.endsWith("*")) {
+                    ps = ps.substring(0, ps.lastIndexOf('*'));
+                    String prefix = ps.substring(0, ps.lastIndexOf('/'));
+                    if (ps.startsWith("$")) {
+                        String[] arr =ps.split("/", 2);
+                        if (arr.length != 2) continue;
+                        try {
+                            ps = System.getenv(arr[0].substring(1)) + '/' + arr[1];
+                        } catch (Exception e){
+                            continue;
+                        }
+                    }
+                    File dir = new File(ps);
+                    if (dir.isDirectory()) {
+                        for (File f: dir.listFiles()) {
+                            if (f.isFile() && f.getPath().endsWith(".jar")) {
+                                cpath.append(":");
+                                cpath.append(prefix + '/' + f.getName());
+                            }
+                        }
+                    }
+                } else {
+                    cpath.append(':');
+                    cpath.append(ps.trim());
+                }
+            }
         }
-        // already use hadoop command to get class path in worker, maybe a better solution in future
-        // env.put("CLASSPATH", cpath.toString());
+        // already use hadoop command to get class path in worker, maybe a
+        // better solution in future
+        env.put("CLASSPATH", cpath.toString());
+        //LOG.info("CLASSPATH =" + cpath.toString());
         // setup LD_LIBARY_PATH path for libhdfs
         env.put("LD_LIBRARY_PATH",
                 "${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
@@ -298,10 +340,13 @@ public class ApplicationMaster {
             if (e.getKey().startsWith("rabit_")) {
                 env.put(e.getKey(), e.getValue());
             }
+            if (e.getKey() == "LIBHDFS_OPTS") {
+                env.put(e.getKey(), e.getValue());
+            }
         }
         env.put("rabit_task_id", String.valueOf(task.taskId));
         env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
+        // ctx.setUser(userName);
         ctx.setEnvironment(env);
         synchronized (this) {
             assert (!this.runningTasks.containsKey(container.getId()));
@@ -376,8 +421,17 @@ public class ApplicationMaster {
         Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
         for (ContainerId cid : failed) {
             TaskRecord r = runningTasks.remove(cid);
-            if (r == null)
+            if (r == null) {
                 continue;
+            }
+            LOG.info("Task "
+                    + r.taskId
+                    + "failed on "
+                    + r.container.getId()
+                    + ". See LOG at : "
+                    + String.format("http://%s/node/containerlogs/%s/"
+                            + userName, r.container.getNodeHttpAddress(),
+                            r.container.getId()));
             r.attemptCounter += 1;
             r.container = null;
             tasks.add(r);
@@ -411,22 +465,26 @@ public class ApplicationMaster {
                 finishedTasks.add(r);
                 runningTasks.remove(s.getContainerId());
             } else {
-                switch (exstatus) {
-                case ContainerExitStatus.KILLED_EXCEEDED_PMEM:
-                    this.abortJob("[Rabit] Task "
-                            + r.taskId
-                            + " killed because of exceeding allocated physical memory");
-                    break;
-                case ContainerExitStatus.KILLED_EXCEEDED_VMEM:
-                    this.abortJob("[Rabit] Task "
-                            + r.taskId
-                            + " killed because of exceeding allocated virtual memory");
-                    break;
-                default:
-                    LOG.info("[Rabit] Task " + r.taskId
-                            + " exited with status " + exstatus);
-                    failed.add(s.getContainerId());
+                try {
+                    if (exstatus == ContainerExitStatus.class.getField(
+                            "KILLED_EXCEEDED_PMEM").getInt(null)) {
+                        this.abortJob("[Rabit] Task "
+                                + r.taskId
+                                + " killed because of exceeding allocated physical memory");
+                        continue;
+                    }
+                    if (exstatus == ContainerExitStatus.class.getField(
+                            "KILLED_EXCEEDED_VMEM").getInt(null)) {
+                        this.abortJob("[Rabit] Task "
+                                + r.taskId
+                                + " killed because of exceeding allocated virtual memory");
+                        continue;
+                    }
+                } catch (Exception e) {
                 }
+                LOG.info("[Rabit] Task " + r.taskId + " exited with status "
+                        + exstatus + " Diagnostics:"+ s.getDiagnostics());
+                failed.add(s.getContainerId());
             }
         }
         this.handleFailure(failed);
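
In the last hunk the switch over ContainerExitStatus constants is replaced by reflective Class.getField lookups, apparently so the code still compiles and runs against Hadoop versions whose ContainerExitStatus lacks KILLED_EXCEEDED_PMEM/VMEM; a missing constant just throws, is swallowed by the catch, and falls through to the generic failure path. A rough, self-contained Python analogy of that look-up-by-name pattern (the ExitStatus class and its value are made up):

    # sketch only: resolve optional constants by name so absent ones degrade gracefully
    class ExitStatus(object):            # hypothetical stand-in, not the YARN API
        KILLED_EXCEEDED_PMEM = -104      # illustrative value

    def classify(exstatus):
        for name, reason in [('KILLED_EXCEEDED_PMEM', 'physical memory'),
                             ('KILLED_EXCEEDED_VMEM', 'virtual memory')]:
            code = getattr(ExitStatus, name, None)   # None when the constant does not exist
            if code is not None and exstatus == code:
                return 'killed for exceeding allocated ' + reason
        return 'exited with status %d' % exstatus

    print(classify(-104))   # killed for exceeding allocated physical memory
    print(classify(1))      # exited with status 1
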

View File

@@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationReport;
@@ -19,6 +20,7 @@ import org.apache.hadoop.yarn.api.records.LocalResource;
 import org.apache.hadoop.yarn.api.records.LocalResourceType;
 import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
 import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.QueueInfo;
 import org.apache.hadoop.yarn.api.records.YarnApplicationState;
 import org.apache.hadoop.yarn.client.api.YarnClient;
 import org.apache.hadoop.yarn.client.api.YarnClientApplication;
@@ -43,14 +45,21 @@ public class Client {
     private String appArgs = "";
     // HDFS Path to store temporal result
     private String tempdir = "/tmp";
+    // user name
+    private String userName = "";
     // job name
     private String jobName = "";
+    // queue
+    private String queue = "default";
 
     /**
      * constructor
      * @throws IOException
      */
     private Client() throws IOException {
+        conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/core-site.xml"));
+        conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/hdfs-site.xml"));
         dfs = FileSystem.get(conf);
+        userName = UserGroupInformation.getCurrentUser().getShortUserName();
     }
 
     /**
@@ -127,6 +136,9 @@ public class Client {
             if (e.getKey().startsWith("rabit_")) {
                 env.put(e.getKey(), e.getValue());
             }
+            if (e.getKey() == "LIBHDFS_OPTS") {
+                env.put(e.getKey(), e.getValue());
+            }
         }
         LOG.debug(env);
         return env;
@@ -152,6 +164,8 @@ public class Client {
                 this.jobName = args[++i];
             } else if(args[i].equals("-tempdir")) {
                 this.tempdir = args[++i];
+            } else if(args[i].equals("-queue")) {
+                this.queue = args[++i];
             } else {
                 sargs.append(" ");
                 sargs.append(args[i]);
@@ -168,7 +182,6 @@ public class Client {
         }
         this.initArgs(args);
         // Create yarnClient
-        YarnConfiguration conf = new YarnConfiguration();
         YarnClient yarnClient = YarnClient.createYarnClient();
         yarnClient.init(conf);
         yarnClient.start();
@@ -181,13 +194,14 @@ public class Client {
                 .newRecord(ContainerLaunchContext.class);
         ApplicationSubmissionContext appContext = app
                 .getApplicationSubmissionContext();
         // Submit application
         ApplicationId appId = appContext.getApplicationId();
         // setup cache-files and environment variables
         amContainer.setLocalResources(this.setupCacheFiles(appId));
         amContainer.setEnvironment(this.getEnvironment());
         String cmd = "$JAVA_HOME/bin/java"
-                + " -Xmx256M"
+                + " -Xmx900M"
                 + " org.apache.hadoop.yarn.rabit.ApplicationMaster"
                 + this.cacheFileArg + ' ' + this.appArgs + " 1>"
                 + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
@@ -197,15 +211,15 @@ public class Client {
 
         // Set up resource type requirements for ApplicationMaster
         Resource capability = Records.newRecord(Resource.class);
-        capability.setMemory(256);
+        capability.setMemory(1024);
         capability.setVirtualCores(1);
-        LOG.info("jobname=" + this.jobName);
+        LOG.info("jobname=" + this.jobName + ",username=" + this.userName);
         appContext.setApplicationName(jobName + ":RABIT-YARN");
         appContext.setAMContainerSpec(amContainer);
         appContext.setResource(capability);
-        appContext.setQueue("default");
+        appContext.setQueue(queue);
+        //appContext.setUser(userName);
         LOG.info("Submitting application " + appId);
         yarnClient.submitApplication(appContext);
@@ -218,12 +232,16 @@ public class Client {
             appReport = yarnClient.getApplicationReport(appId);
             appState = appReport.getYarnApplicationState();
         }
         System.out.println("Application " + appId + " finished with"
                 + " state " + appState + " at " + appReport.getFinishTime());
         if (!appReport.getFinalApplicationStatus().equals(
                 FinalApplicationStatus.SUCCEEDED)) {
             System.err.println(appReport.getDiagnostics());
+            System.out.println("Available queues:");
+            for (QueueInfo q : yarnClient.getAllQueues()) {
+                System.out.println(q.getQueueName());
+            }
         }
     }