Merge commit '75bf97b57539e5572e7ae8eba72bac6562c63c07'

Conflicts:
	subtree/rabit/rabit-learn/io/line_split-inl.h
	subtree/rabit/yarn/build.sh
tqchen 2015-03-21 00:48:34 -07:00
commit 9ccbeaa8f0
34 changed files with 856 additions and 201 deletions


@ -19,6 +19,8 @@ namespace utils {
/*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream {
public:
// virtual destructor
virtual ~ISeekStream(void) {}
/*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */

subtree/rabit/rabit-learn/.gitignore (vendored, new file)

@ -0,0 +1,2 @@
config.mk
*.log


@ -38,6 +38,7 @@ class StreamBufferReader {
}
}
}
/*! \brief whether we are reaching the end of file */
inline bool AtEnd(void) const {
return read_len_ == 0;
}


@ -66,27 +66,36 @@ class FileStream : public utils::ISeekStream {
};
/*! \brief line split from normal file system */
class FileSplit : public LineSplitBase {
class FileProvider : public LineSplitter::IFileProvider {
public:
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
LineSplitBase::SplitNames(&fnames_, uri, "#");
explicit FileProvider(const char *uri) {
LineSplitter::SplitNames(&fnames_, uri, "#");
std::vector<size_t> fsize;
for (size_t i = 0; i < fnames_.size(); ++i) {
if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
std::string tmp = fnames_[i].c_str() + 7;
fnames_[i] = tmp;
}
fsize.push_back(GetFileSize(fnames_[i].c_str()));
size_t fz = GetFileSize(fnames_[i].c_str());
if (fz != 0) {
fsize_.push_back(fz);
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~FileSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
}
// destructor
virtual ~FileProvider(void) {}
virtual utils::ISeekStream *Open(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceeds bound");
return new FileStream(fnames_[file_index].c_str(), "rb");
}
virtual const std::vector<size_t> &FileSize(void) const {
return fsize_;
}
private:
// file sizes
std::vector<size_t> fsize_;
// file names
std::vector<std::string> fnames_;
// get file size
inline static size_t GetFileSize(const char *fname) {
std::FILE *fp = utils::FopenCheck(fname, "rb");
@ -96,10 +105,6 @@ class FileSplit : public LineSplitBase {
std::fclose(fp);
return fsize;
}
private:
// file names
std::vector<std::string> fnames_;
};
} // namespace io
} // namespace rabit


@ -6,6 +6,7 @@
* \author Tianqi Chen
*/
#include <string>
#include <cstdlib>
#include <vector>
#include <hdfs.h>
#include <errno.h>
@ -15,11 +16,15 @@
/*! \brief io interface */
namespace rabit {
namespace io {
class HDFSStream : public utils::ISeekStream {
class HDFSStream : public ISeekStream {
public:
HDFSStream(hdfsFS fs, const char *fname, const char *mode)
: fs_(fs), at_end_(false) {
int flag;
HDFSStream(hdfsFS fs,
const char *fname,
const char *mode,
bool disconnect_when_done)
: fs_(fs), at_end_(false),
disconnect_when_done_(disconnect_when_done) {
int flag = 0;
if (!strcmp(mode, "r")) {
flag = O_RDONLY;
} else if (!strcmp(mode, "w")) {
@ -35,6 +40,9 @@ class HDFSStream : public utils::ISeekStream {
}
virtual ~HDFSStream(void) {
this->Close();
if (disconnect_when_done_) {
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
}
}
virtual size_t Read(void *ptr, size_t size) {
tSize nread = hdfsRead(fs_, fp_, ptr, size);
@ -86,52 +94,69 @@ class HDFSStream : public utils::ISeekStream {
}
}
inline static std::string GetNameNode(void) {
const char *nn = getenv("rabit_hdfs_namenode");
if (nn == NULL) {
return std::string("default");
} else {
return std::string(nn);
}
}
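// Note (sketch): rabit_hdfs_namenode is exported by the tracker side; in this same
// commit rabit_yarn.py passes its --name-node option through as this variable, and
// "default" lets libhdfs fall back to the namenode configured for the cluster.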
private:
hdfsFS fs_;
hdfsFile fp_;
bool at_end_;
bool disconnect_when_done_;
};
/*! \brief line split from HDFS file system */
class HDFSSplit : public LineSplitBase {
class HDFSProvider : public LineSplitter::IFileProvider {
public:
explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) {
fs_ = hdfsConnect("default", 0);
explicit HDFSProvider(const char *uri) {
fs_ = hdfsConnect(HDFSStream::GetNameNode().c_str(), 0);
utils::Check(fs_ != NULL, "error when connecting to default HDFS");
std::vector<std::string> paths;
LineSplitBase::SplitNames(&paths, uri, "#");
LineSplitter::SplitNames(&paths, uri, "#");
// get the files
std::vector<size_t> fsize;
for (size_t i = 0; i < paths.size(); ++i) {
hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
utils::Check(info != NULL, "path %s does not exist", paths[i].c_str());
if (info->mKind == 'D') {
int nentry;
hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
utils::Check(files != NULL, "error when ListDirectory %s", info->mName);
for (int i = 0; i < nentry; ++i) {
if (files[i].mKind == 'F') {
fsize.push_back(files[i].mSize);
if (files[i].mKind == 'F' && files[i].mSize != 0) {
fsize_.push_back(files[i].mSize);
fnames_.push_back(std::string(files[i].mName));
}
}
hdfsFreeFileInfo(files, nentry);
} else {
fsize.push_back(info->mSize);
if (info->mSize != 0) {
fsize_.push_back(info->mSize);
fnames_.push_back(std::string(info->mName));
}
}
hdfsFreeFileInfo(info, 1);
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~HDFSSplit(void) {}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
virtual ~HDFSProvider(void) {
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
}
virtual const std::vector<size_t> &FileSize(void) const {
return fsize_;
}
virtual ISeekStream *Open(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceeds bound");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
}
private:
// hdfs handle
hdfsFS fs_;
// file sizes
std::vector<size_t> fsize_;
// file names
std::vector<std::string> fnames_;
};


@ -30,16 +30,16 @@ inline InputSplit *CreateInputSplit(const char *uri,
return new SingleFileSplit(uri);
}
if (!strncmp(uri, "file://", 7)) {
return new FileSplit(uri, part, nsplit);
return new LineSplitter(new FileProvider(uri), part, nsplit);
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSSplit(uri, part, nsplit);
return new LineSplitter(new HDFSProvider(uri), part, nsplit);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif
}
return new FileSplit(uri, part, nsplit);
return new LineSplitter(new FileProvider(uri), part, nsplit);
}
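// Usage sketch (illustrative only; names taken from the headers in this diff,
// the uri paths are hypothetical):
//   InputSplit *split = CreateInputSplit("hdfs://data/part0#hdfs://data/part1",
//                                        rank, nsplit);
//   std::string line;
//   while (split->NextLine(&line)) {
//     // each worker only iterates over the byte range assigned to (rank, nsplit)
//   }
//   delete split;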
/*!
* \brief create an stream, the stream must be able to close
@ -55,7 +55,8 @@ inline IStream *CreateStream(const char *uri, const char *mode) {
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSStream(hdfsConnect("default", 0), uri, mode);
return new HDFSStream(hdfsConnect(HDFSStream::GetNameNode().c_str(), 0),
uri, mode, true);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif


@ -19,6 +19,7 @@ namespace rabit {
* \brief namespace to handle input split and filesystem interfacing
*/
namespace io {
/*! \brief reused ISeekStream's definition */
typedef utils::ISeekStream ISeekStream;
/*!
* \brief user facing input split helper,


@ -15,11 +15,42 @@
namespace rabit {
namespace io {
class LineSplitBase : public InputSplit {
/*! \brief class that splits the files by line */
class LineSplitter : public InputSplit {
public:
virtual ~LineSplitBase() {
if (fs_ != NULL) delete fs_;
class IFileProvider {
public:
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream, positioned at the head of the file
* the seek stream's resource can be freed by calling delete
*/
virtual ISeekStream *Open(size_t file_index) = 0;
/*!
* \return const reference to the size of each file
*/
virtual const std::vector<size_t> &FileSize(void) const = 0;
// virtual destructor
virtual ~IFileProvider() {}
};
// constructor
explicit LineSplitter(IFileProvider *provider,
unsigned rank,
unsigned nsplit)
: provider_(provider), fs_(NULL),
reader_(kBufferSize) {
this->Init(provider_->FileSize(), rank, nsplit);
}
// destructor
virtual ~LineSplitter() {
if (fs_ != NULL) {
delete fs_; fs_ = NULL;
}
// delete provider after destructing the streams
delete provider_;
}
// get next line
virtual bool NextLine(std::string *out_data) {
if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false;
@ -29,15 +60,15 @@ class LineSplitBase : public InputSplit {
if (reader_.AtEnd()) {
if (out_data->length() != 0) return true;
file_ptr_ += 1;
if (offset_curr_ >= offset_end_) return false;
if (offset_curr_ != file_offset_[file_ptr_]) {
utils::Error("warning:std::FILE size not calculated correctly\n");
utils::Error("warning: FILE size not calculated correctly\n");
offset_curr_ = file_offset_[file_ptr_];
}
if (offset_curr_ >= offset_end_) return false;
utils::Assert(file_ptr_ + 1 < file_offset_.size(),
"boundary check");
delete fs_;
fs_ = this->GetFile(file_ptr_);
fs_ = provider_->Open(file_ptr_);
reader_.set_stream(fs_);
} else {
++offset_curr_;
@ -51,15 +82,27 @@ class LineSplitBase : public InputSplit {
}
}
}
protected:
// constructor
LineSplitBase(void)
: fs_(NULL), reader_(kBufferSize) {
/*!
* \brief split the given uri into file names
* \param out_fname output file names
* \param uri_ the input uri string
* \param dlm the delimiter
*/
inline static void SplitNames(std::vector<std::string> *out_fname,
const char *uri_,
const char *dlm) {
std::string uri = uri_;
char *p = std::strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = std::strtok(NULL, dlm);
}
}
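// Example (sketch): SplitNames(&out, "file://a.txt#file://b.txt", "#") leaves
// out == {"file://a.txt", "file://b.txt"}; the uri is copied into a local
// std::string first because std::strtok modifies the buffer it tokenizes.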
private:
/*!
* \brief initialize the line splitter
* \param file_size, size of each std::FILEs
* \param file_size size of each file
* \param rank the current rank of the data
* \param nsplit number of split we will divide the data into
*/
@ -82,7 +125,7 @@ class LineSplitBase : public InputSplit {
file_ptr_end_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_end_) - file_offset_.begin() - 1;
fs_ = GetFile(file_ptr_);
fs_ = provider_->Open(file_ptr_);
reader_.set_stream(fs_);
// try to set the starting position correctly
if (file_offset_[file_ptr_] != offset_begin_) {
@ -94,33 +137,15 @@ class LineSplitBase : public InputSplit {
}
}
}
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream at head of std::FILE
*/
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
/*!
* \brief split names given
* \param out_fname output std::FILE names
* \param uri_ the iput uri std::FILE
* \param dlm deliminetr
*/
inline static void SplitNames(std::vector<std::string> *out_fname,
const char *uri_,
const char *dlm) {
std::string uri = uri_;
char *p = std::strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = std::strtok(NULL, dlm);
}
}
private:
/*! \brief FileProvider */
IFileProvider *provider_;
/*! \brief current input stream */
utils::ISeekStream *fs_;
/*! \brief std::FILE pointer of which std::FILE to read on */
/*! \brief file pointer of which file to read on */
size_t file_ptr_;
/*! \brief std::FILE pointer where the end of std::FILE lies */
/*! \brief file pointer where the end of file lies */
size_t file_ptr_end_;
/*! \brief get the current offset */
size_t offset_curr_;
@ -128,7 +153,7 @@ class LineSplitBase : public InputSplit {
size_t offset_begin_;
/*! \brief end of the offset */
size_t offset_end_;
/*! \brief byte-offset of each std::FILE */
/*! \brief byte-offset of each file */
std::vector<size_t> file_offset_;
/*! \brief buffer reader */
StreamBufferReader reader_;


@ -1,4 +1,10 @@
# specify tensor path
ifneq ("$(wildcard ../config.mk)","")
config = ../config.mk
else
config = ../make/config.mk
endif
include $(config)
BIN = linear.rabit
MOCKBIN= linear.mock
MPIBIN =
@ -6,10 +12,10 @@ MPIBIN =
OBJ = linear.o
# common build script for programs
include ../make/config.mk
include ../make/common.mk
CFLAGS+=-fopenmp
linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
# dependenies here
linear.rabit: linear.o lib
linear.mock: linear.o lib


@ -206,21 +206,22 @@ int main(int argc, char *argv[]) {
rabit::Finalize();
return 0;
}
rabit::linear::LinearObjFunction linear;
rabit::linear::LinearObjFunction *linear = new rabit::linear::LinearObjFunction();
if (!strcmp(argv[1], "stdin")) {
linear.LoadData(argv[1]);
linear->LoadData(argv[1]);
rabit::Init(argc, argv);
} else {
rabit::Init(argc, argv);
linear.LoadData(argv[1]);
linear->LoadData(argv[1]);
}
for (int i = 2; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
linear.SetParam(name, val);
linear->SetParam(name, val);
}
}
linear.Run();
linear->Run();
delete linear;
rabit::Finalize();
return 0;
}


@ -26,10 +26,11 @@ struct LinearModel {
int reserved[16];
// constructor
ModelParam(void) {
memset(this, 0, sizeof(ModelParam));
base_score = 0.5f;
num_feature = 0;
loss_type = 1;
std::memset(reserved, 0, sizeof(reserved));
num_feature = 0;
}
// initialize base score
inline void InitBaseScore(void) {
@ -119,7 +120,7 @@ struct LinearModel {
}
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
}
inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
inline void Save(rabit::IStream &fo, const float *wptr = NULL) {
fo.Write(&param, sizeof(param));
if (wptr == NULL) wptr = weight;
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));


@ -6,12 +6,13 @@ then
fi
# put the local training file to HDFS
hadoop fs -rm -r -f $2/data
hadoop fs -rm -r -f $2/mushroom.linear.model
hadoop fs -mkdir $2/data
hadoop fs -put ../data/agaricus.txt.train $2/data
# submit to hadoop
../../tracker/rabit_yarn.py -n $1 --vcores 1 linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
../../tracker/rabit_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
# get the final model file
hadoop fs -get $2/mushroom.linear.model ./linear.model


@ -6,7 +6,7 @@
#
# - copy this file to the root of rabit-learn folder
# - modify the configuration you want
# - type make or make -j n for parallel build
# - type make or make -j n on each of the folder
#----------------------------------------------------
# choice of compiler


@ -145,8 +145,9 @@ class LBFGSSolver {
if (silent == 0 && rabit::GetRank() == 0) {
rabit::TrackerPrintf
("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu\n",
gstate.num_dim, gstate.init_objval, gstate.size_memory);
("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu, RAM-approx=%lu\n",
gstate.num_dim, gstate.init_objval, gstate.size_memory,
gstate.MemCost() + hist.MemCost());
}
}
}
@ -176,7 +177,7 @@ class LBFGSSolver {
// swap new weight
std::swap(g.weight, g.grad);
// check stop condition
if (gstate.num_iteration > min_lbfgs_iter) {
if (gstate.num_iteration > static_cast<size_t>(min_lbfgs_iter)) {
if (g.old_objval - g.new_objval < lbfgs_stop_tol * g.init_objval) {
return true;
}
@ -195,7 +196,7 @@ class LBFGSSolver {
/*! \brief run optimization */
virtual void Run(void) {
this->Init();
while (gstate.num_iteration < max_lbfgs_iter) {
while (gstate.num_iteration < static_cast<size_t>(max_lbfgs_iter)) {
if (this->UpdateOneIter()) break;
}
if (silent == 0 && rabit::GetRank() == 0) {
@ -225,7 +226,7 @@ class LBFGSSolver {
const size_t num_dim = gstate.num_dim;
const DType *gsub = grad + range_begin_;
const size_t nsub = range_end_ - range_begin_;
double vdot;
double vdot = 0.0;
if (n != 0) {
// hist[m + n - 1] stores old gradient
Minus(hist[m + n - 1], gsub, hist[m + n - 1], nsub);
@ -241,15 +242,19 @@ class LBFGSSolver {
idxset.push_back(std::make_pair(m + j, 2 * m));
idxset.push_back(std::make_pair(m + j, m + n - 1));
}
// calculate dot products
std::vector<double> tmp(idxset.size());
for (size_t i = 0; i < tmp.size(); ++i) {
tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second);
}
rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size());
for (size_t i = 0; i < tmp.size(); ++i) {
gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
}
// BFGS steps, use vector-free update
// parameterize vector using basis in hist
std::vector<double> alpha(n);
@ -279,6 +284,7 @@ class LBFGSSolver {
double beta = vsum / gstate.DotBuf(j, m + j);
delta[j] = delta[j] + (alpha[j] - beta);
}
// set all to zero
std::fill(dir, dir + num_dim, 0.0f);
DType *dirsub = dir + range_begin_;
@ -291,6 +297,7 @@ class LBFGSSolver {
}
FixDirL1Sign(dirsub, hist[2 * m], nsub);
vdot = -Dot(dirsub, hist[2 * m], nsub);
// allreduce to get full direction
rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
@ -482,6 +489,7 @@ class LBFGSSolver {
num_iteration = 0;
num_dim = 0;
old_objval = 0.0;
offset_ = 0;
}
~GlobalState(void) {
if (grad != NULL) {
@ -496,6 +504,10 @@ class LBFGSSolver {
data.resize(n * n, 0.0);
this->AllocSpace();
}
// memory cost
inline size_t MemCost(void) const {
return sizeof(DType) * 3 * num_dim;
}
inline double &DotBuf(size_t i, size_t j) {
if (i > j) std::swap(i, j);
return data[MapIndex(i, offset_, size_memory) * (size_memory * 2 + 1) +
@ -565,6 +577,10 @@ class LBFGSSolver {
size_t n = size_memory * 2 + 1;
dptr_ = new DType[n * stride_];
}
// memory cost
inline size_t MemCost(void) const {
return sizeof(DType) * (size_memory_ * 2 + 1) * stride_;
}
// fetch element from rolling array
inline const DType *operator[](size_t i) const {
return dptr_ + MapIndex(i, offset_, size_memory_) * stride_;


@ -82,6 +82,10 @@ struct SparseMat {
inline size_t NumRow(void) const {
return row_ptr.size() - 1;
}
// memory cost
inline size_t MemCost(void) const {
return data.size() * sizeof(Entry);
}
// maximum feature dimension
size_t feat_dim;
std::vector<size_t> row_ptr;


@ -26,6 +26,9 @@ AllreduceBase::AllreduceBase(void) {
world_size = -1;
hadoop_mode = 0;
version_number = 0;
// 32 K items
reduce_ring_mincount = 32 << 10;
// tracker URL
task_id = "NULL";
err_link = NULL;
this->SetParam("rabit_reduce_buffer", "256MB");
@ -33,6 +36,7 @@ AllreduceBase::AllreduceBase(void) {
env_vars.push_back("rabit_task_id");
env_vars.push_back("rabit_num_trial");
env_vars.push_back("rabit_reduce_buffer");
env_vars.push_back("rabit_reduce_ring_mincount");
env_vars.push_back("rabit_tracker_uri");
env_vars.push_back("rabit_tracker_port");
}
@ -116,6 +120,27 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.SendStr(msg);
tracker.Close();
}
// util to parse data with unit suffix
inline size_t ParseUnit(const char *name, const char *val) {
char unit;
uint64_t amount;
int n = sscanf(val, "%lu%c", &amount, &unit);
if (n == 2) {
switch (unit) {
case 'B': return amount;
case 'K': return amount << 10UL;
case 'M': return amount << 20UL;
case 'G': return amount << 30UL;
default: utils::Error("invalid format for %s", name); return 0;
}
} else if (n == 1) {
return amount;
} else {
utils::Error("invalid format for %s," \
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
return 0;
}
}
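// Example (sketch): ParseUnit("rabit_reduce_ring_mincount", "32K") yields 32 << 10,
// "256MB" (the default reduce buffer setting) yields 256 << 20 since only the first
// unit character is read, and a bare "1024" is returned as-is; for
// rabit_reduce_buffer the byte count is then rounded up into 8-byte units below.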
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -127,21 +152,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_task_id")) task_id = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
reduce_ring_mincount = ParseUnit(name, val);
}
if (!strcmp(name, "rabit_reduce_buffer")) {
char unit;
uint64_t amount;
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
switch (unit) {
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
case 'K': reduce_buffer_size = amount << 7UL; break;
case 'M': reduce_buffer_size = amount << 17UL; break;
case 'G': reduce_buffer_size = amount << 27UL; break;
default: utils::Error("invalid format for reduce buffer");
}
} else {
utils::Error("invalid format for reduce_buffer,"\
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
}
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
}
}
/*!
@ -341,6 +356,28 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
if (count > reduce_ring_mincount) {
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
} else {
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
}
}
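// Note (sketch): reduce_ring_mincount defaults to 32 << 10 items (see the constructor
// above) and can be overridden via the new rabit_reduce_ring_mincount parameter in
// SetParam; the ringallreduce_10_10k test target passes rabit_reduce_ring_mincount=10
// so that even small payloads exercise the ring path.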
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess;
// total size of message
@ -411,7 +448,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
@ -599,5 +636,217 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
return kSuccess;
}
/*!
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to rely on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
const size_t stop_read = total_size + slice_begin;
const size_t stop_write = total_size + slice_begin - size_prev_slice;
size_t write_ptr = slice_begin;
size_t read_ptr = slice_end;
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
if (len != -1) {
read_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) return ReportError(&next, ret);
}
}
if (write_ptr < read_ptr && write_ptr != stop_write) {
size_t size = std::min(read_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
* and will return the cause of failure
*
* Ring-based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
AllreduceBase::ReturnType
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to rely on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// total size of message
const size_t total_size = type_nbytes * count;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t r = static_cast<size_t>(next.rank);
size_t write_ptr = std::min(r * step, count) * type_nbytes;
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
size_t reduce_ptr = read_ptr;
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// position to stop reading
const size_t stop_read = total_size + write_ptr;
// position to stop writing
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
if (stop_write > stop_read) {
stop_write -= total_size;
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
}
// use ring buffer in next position
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
// set size_read to read pointer for ring buffer to work properly
next.size_read = read_ptr;
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
}
// sync the rate
read_ptr = next.size_read;
utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
const size_t buffer_size = next.buffer_size;
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
while (reduce_ptr < max_reduce) {
size_t bstart = reduce_ptr % buffer_size;
size_t nread = std::min(buffer_size - bstart,
max_reduce - reduce_ptr);
size_t rstart = reduce_ptr % total_size;
nread = std::min(nread, total_size - rstart);
reducer(next.buffer_head + bstart,
sendrecvbuf + rstart,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
reduce_ptr += nread;
}
}
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
if (ret != kSuccess) return ret;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t begin = std::min(rank * step, count) * type_nbytes;
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
// previous rank
int prank = ring_prev->rank;
// get rank of previous
return TryAllgatherRing
(sendrecvbuf_, type_nbytes * count,
begin, end,
(std::min((prank + 1) * step, count) -
std::min(prank * step, count)) * type_nbytes);
}
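// Worked example (sketch): with world_size = 4 and count = 10, step = (10 + 3) / 4 = 3,
// so ranks 0..3 own element ranges [0,3), [3,6), [6,9), [9,10). The reduce-scatter
// phase leaves each rank with the reduced values of its own range, and the allgather
// phase then circulates the ranges around the ring until every rank holds the full result.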
} // namespace engine
} // namespace rabit


@ -279,14 +279,18 @@ class AllreduceBase : public IEngine {
* position after protect_start
* \param protect_start all data start from protect_start is still needed in buffer
* read shall not override this
* \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
* \return the type of reading
*/
inline ReturnType ReadToRingBuffer(size_t protect_start) {
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size;
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
size_t nmax = max_size_read - size_read;
nmax = std::min(nmax, buffer_size - ngap);
nmax = std::min(nmax, buffer_size - offset);
if (nmax == 0) return kSuccess;
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
@ -380,13 +384,79 @@ class AllreduceBase : public IEngine {
ReduceFunction reducer);
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and recving data
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin, size_t slice_end,
size_t size_prev_slice);
/*!
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
*
* after the function, node k gets the k-th segment of the reduction result
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
* where step = ceil(count / world_size)
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm, reduce-scatter + allgather
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief function used to report error when a link goes wrong
* \param link the pointer to the link who causes the error
@ -432,6 +502,10 @@ class AllreduceBase : public IEngine {
int slave_port, nport_trial;
// reduce buffer size
size_t reduce_buffer_size;
// reduction method
int reduce_method;
// minimum count of cells to use ring based method
size_t reduce_ring_mincount;
// current rank
int rank;
// world size


@ -81,18 +81,18 @@ class AllreduceMock : public AllreduceRobust {
ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com);
}
tsum_allreduce = 0.0;
time_checkpoint = utils::GetTime();
double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
<< "local_size=" << local_chkpt[local_chkpt_version].length()
<< "check_tcost="<< tcost <<" sec,"
<< "allreduce_tcost=" << tsum_allreduce << " sec,"
<< "between_chpt=" << tbet_chkpt << "sec\n";
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
<< ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
tsum_allreduce = 0.0;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {


@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
}
utils::Assert(min_write <= links[pid].size_read, "boundary check");
ReturnType ret = links[pid].ReadToRingBuffer(min_write);
ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
if (ret != kSuccess) {
return ReportError(&links[pid], ret);
}


@ -287,7 +287,6 @@ class AllreduceRobust : public AllreduceBase {
if (seqno_.size() == 0) return -1;
return seqno_.back();
}
private:
// sequence number of each
std::vector<int> seqno_;


@ -1,7 +1,7 @@
export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -lrt -L../lib
export LDFLAGS= -L../lib -pthread -lm -lrt
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11
# specify tensor path
@ -29,7 +29,7 @@ local_recover: local_recover.o $(RABIT_OBJ)
lazy_recover: lazy_recover.o $(RABIT_OBJ)
$(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )


@ -24,3 +24,6 @@ lazy_recover_10_10k_die_hard:
lazy_recover_10_10k_die_same:
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
ringallreduce_10_10k:
../tracker/rabit_demo.py -v 1 -n 10 model_recover 100 rabit_reduce_ring_mincount=10


@ -9,4 +9,4 @@ the example guidelines are in the script themselfs
* Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
- It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py
- However, it is highly recommended to use rabit_yarn.py because it allocates resources more precisely and fits machine learning scenarios
* Sun Grid engine: [rabit_sge.py](rabit_sge.py)


@ -1,7 +1,6 @@
#!/usr/bin/python
#!/usr/bin/env python
"""
This is the demo submission script of rabit, it is created to
submit rabit jobs using hadoop streaming
This is the demo submission script of rabit for submitting jobs on the local machine
"""
import argparse
import sys
@ -43,7 +42,7 @@ def exec_cmd(cmd, taskid, worker_env):
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
cmd[0] = './' + cmd[0]
cmd = ' '.join(cmd)
env = {}
env = os.environ.copy()
for k, v in worker_env.items():
env[k] = str(v)
env['rabit_task_id'] = str(taskid)


@ -1,4 +1,4 @@
#!/usr/bin/python
#!/usr/bin/env python
"""
Deprecated


@ -1,7 +1,6 @@
#!/usr/bin/python
#!/usr/bin/env python
"""
This is the demo submission script of rabit, it is created to
submit rabit jobs using hadoop streaming
Submission script to submit rabit jobs using MPI
"""
import argparse
import sys


@ -0,0 +1,69 @@
#!/usr/bin/env python
"""
Submit rabit jobs to Sun Grid Engine
"""
import argparse
import sys
import os
import subprocess
import rabit_tracker as tracker
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs to Sun Grid Engine')
parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker processes to be launched')
parser.add_argument('-q', '--queue', default='default', type=str,
help = 'the queue we want to submit the job to')
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address of the submission machine; specify it if it cannot be guessed automatically')
parser.add_argument('--vcores', default = 1, type=int,
help = 'number of vcores to request for each job, set it if each rabit job is multi-threaded')
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
parser.add_argument('--logdir', default='auto', help = 'customize the directory to place the logs')
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
help = 'print more messages into the console')
parser.add_argument('command', nargs='+',
help = 'command for rabit program')
args = parser.parse_args()
if args.jobname == 'auto':
args.jobname = ('rabit%d.' % args.nworker) + args.command[0].split('/')[-1];
if args.logdir == 'auto':
args.logdir = args.jobname + '.log'
if os.path.exists(args.logdir):
if not os.path.isdir(args.logdir):
raise RuntimeError('specified logdir %s is a file instead of directory' % args.logdir)
else:
os.mkdir(args.logdir)
runscript = '%s/runrabit.sh' % args.logdir
fo = open(runscript, 'w')
fo.write('\"$@\"\n')
fo.close()
#
# submission script using MPI
#
def sge_submit(nslave, worker_args, worker_envs):
"""
customized submit script that submits nslave jobs; each job takes args as parameters
note this can be a lambda function carrying additional parameters as input
Parameters
nslave  number of slave processes to start up
args    arguments used to launch each job, usually including master_uri
and the parameters passed into submit
"""
env_arg = ','.join(['%s=\"%s\"' % (k, str(v)) for k, v in worker_envs.items()])
cmd = 'qsub -cwd -t 1-%d -S /bin/bash' % nslave
if args.queue != 'default':
cmd += ' -q %s' % args.queue
cmd += ' -N %s ' % args.jobname
cmd += ' -e %s -o %s' % (args.logdir, args.logdir)
cmd += ' -pe orte %d' % (args.vcores)
cmd += ' -v %s,PATH=${PATH}:.' % env_arg
cmd += ' %s %s' % (runscript, ' '.join(args.command + worker_args))
print cmd
subprocess.check_call(cmd, shell = True)
print 'Waiting for the jobs to get up...'
# call submit, with nslave, the commands to run each job and submit function
tracker.submit(args.nworker, [], fun_submit = sge_submit, verbose = args.verbose)


@ -13,6 +13,7 @@ import socket
import struct
import subprocess
import random
import time
from threading import Thread
"""
@ -188,6 +189,7 @@ class Tracker:
vlst.reverse()
rlst += vlst
return rlst
def get_ring(self, tree_map, parent_map):
"""
get a ring connection used to recover local data
@ -202,14 +204,44 @@ class Tracker:
rnext = (r + 1) % nslave
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map
def get_link_map(self, nslave):
"""
get the link map; this is a bit hacky and calls for a better algorithm
to place similar nodes together
"""
tree_map, parent_map = self.get_tree(nslave)
ring_map = self.get_ring(tree_map, parent_map)
rmap = {0 : 0}
k = 0
for i in range(nslave - 1):
k = ring_map[k][1]
rmap[k] = i + 1
ring_map_ = {}
tree_map_ = {}
parent_map_ ={}
for k, v in ring_map.items():
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
for k, v in tree_map.items():
tree_map_[rmap[k]] = [rmap[x] for x in v]
for k, v in parent_map.items():
if k != 0:
parent_map_[rmap[k]] = rmap[v]
else:
parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_
def handle_print(self,slave, msg):
sys.stdout.write(msg)
def log_print(self, msg, level):
if level == 1:
if self.verbose:
sys.stderr.write(msg + '\n')
else:
sys.stderr.write(msg + '\n')
def accept_slaves(self, nslave):
# set of nodes that have finished the job
shutdown = {}
@ -241,31 +273,40 @@ class Tracker:
assert s.cmd == 'start'
if s.world_size > 0:
nslave = s.world_size
tree_map, parent_map = self.get_tree(nslave)
ring_map = self.get_ring(tree_map, parent_map)
tree_map, parent_map, ring_map = self.get_link_map(nslave)
# set of nodes that is pending for getting up
todo_nodes = range(nslave)
random.shuffle(todo_nodes)
else:
assert s.world_size == -1 or s.world_size == nslave
if s.cmd == 'recover':
assert s.rank >= 0
rank = s.decide_rank(job_map)
# batch assignment of ranks
if rank == -1:
assert len(todo_nodes) != 0
pending.append(s)
if len(pending) == len(todo_nodes):
pending.sort(key = lambda x : x.host)
for s in pending:
rank = todo_nodes.pop(0)
if s.jobid != 'NULL':
job_map[s.jobid] = rank
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.wait_accept > 0:
wait_conn[rank] = s
self.log_print('Receive %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
if len(todo_nodes) == 0:
self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.cmd != 'start':
self.log_print('Receive %s signal from %d' % (s.cmd, s.rank), 1)
self.start_time = time.time()
else:
self.log_print('Receive %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
self.log_print('Receive %s signal from %d' % (s.cmd, s.rank), 1)
if s.wait_accept > 0:
wait_conn[rank] = s
self.log_print('@tracker All nodes finished their jobs', 2)
self.end_time = time.time()
self.log_print('@tracker %s secs between node start and job finish' % str(self.end_time - self.start_time), 2)
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
master = Tracker(verbose = verbose, hostIP = hostIP)


@ -1,4 +1,4 @@
#!/usr/bin/python
#!/usr/bin/env python
"""
This is a script to submit rabit job via Yarn
rabit will run as a Yarn application
@ -13,6 +13,7 @@ import rabit_tracker as tracker
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
YARN_BOOT_PY = os.path.dirname(__file__) + '/../yarn/run_hdfs_prog.py'
if not os.path.exists(YARN_JAR_PATH):
warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
@ -21,7 +22,7 @@ if not os.path.exists(YARN_JAR_PATH):
subprocess.check_call(cmd, shell = True, env = os.environ)
assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
hadoop_binary = 'hadoop'
hadoop_binary = None
# code
hadoop_home = os.getenv('HADOOP_HOME')
@ -38,6 +39,8 @@ parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
help = 'print more messages into the console')
parser.add_argument('-q', '--queue', default='default', type=str,
help = 'the queue we want to submit the job to')
parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
parser.add_argument('-f', '--files', default = [], action='append',
@ -56,6 +59,11 @@ parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
'if you are running multi-threading rabit,'\
'so that each node can occupy all the mapper slots in a machine for maximum performance')
parser.add_argument('--libhdfs-opts', default='-Xmx128m', type=str,
help = 'setting to be passed to libhdfs')
parser.add_argument('--name-node', default='default', type=str,
help = 'the namenode address of hdfs, libhdfs should connect to, normally leave it as default')
parser.add_argument('command', nargs='+',
help = 'command for rabit program')
args = parser.parse_args()
@ -87,7 +95,7 @@ if hadoop_version < 2:
print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]
def submit_yarn(nworker, worker_args, worker_env):
fset = set([YARN_JAR_PATH])
fset = set([YARN_JAR_PATH, YARN_BOOT_PY])
if args.auto_file_cache != 0:
for i in range(len(args.command)):
f = args.command[i]
@ -96,7 +104,7 @@ def submit_yarn(nworker, worker_args, worker_env):
if i == 0:
args.command[i] = './' + args.command[i].split('/')[-1]
else:
args.command[i] = args.command[i].split('/')[-1]
args.command[i] = './' + args.command[i].split('/')[-1]
if args.command[0].endswith('.py'):
flst = [WRAPPER_PATH + '/rabit.py',
WRAPPER_PATH + '/librabit_wrapper.so',
@ -112,6 +120,8 @@ def submit_yarn(nworker, worker_args, worker_env):
env['rabit_cpu_vcores'] = str(args.vcores)
env['rabit_memory_mb'] = str(args.memory_mb)
env['rabit_world_size'] = str(args.nworker)
env['rabit_hdfs_opts'] = str(args.libhdfs_opts)
env['rabit_hdfs_namenode'] = str(args.name_node)
if args.files != None:
for flst in args.files:
@ -121,7 +131,8 @@ def submit_yarn(nworker, worker_args, worker_env):
cmd += ' -file %s' % f
cmd += ' -jobname %s ' % args.jobname
cmd += ' -tempdir %s ' % args.tempdir
cmd += (' '.join(args.command + worker_args))
cmd += ' -queue %s ' % args.queue
cmd += (' '.join(['./run_hdfs_prog.py'] + args.command + worker_args))
if args.verbose != 0:
print cmd
subprocess.check_call(cmd, shell = True, env = env)


@ -1 +0,0 @@
foler used to hold generated class files


@ -1,8 +1,8 @@
#!/bin/bash
if [ -z "$HADOOP_PREFIX" ]; then
echo "cannot found $HADOOP_PREFIX in the environment variable, please set it properly"
exit 1
if [ ! -d bin ]; then
mkdir bin
fi
CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
CPATH=`${HADOOP_HOME}/bin/hadoop classpath`
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
jar cf rabit-yarn.jar -C bin .


@ -0,0 +1,45 @@
#!/usr/bin/env python
"""
this script helps set up the CLASSPATH environment for HDFS before running a program
that links with libhdfs
"""
import glob
import sys
import os
import subprocess
if len(sys.argv) < 2:
print 'Usage: the command you want to run'
hadoop_home = os.getenv('HADOOP_HOME')
hdfs_home = os.getenv('HADOOP_HDFS_HOME')
java_home = os.getenv('JAVA_HOME')
if hadoop_home is None:
hadoop_home = os.getenv('HADOOP_PREFIX')
assert hadoop_home is not None, 'need to set HADOOP_HOME'
assert hdfs_home is not None, 'need to set HADOOP_HDFS_HOME'
assert java_home is not None, 'need to set JAVA_HOME'
(classpath, err) = subprocess.Popen('%s/bin/hadoop classpath' % hadoop_home,
stdout=subprocess.PIPE, shell = True,
env = os.environ).communicate()
cpath = []
for f in classpath.split(':'):
cpath += glob.glob(f)
lpath = []
lpath.append('%s/lib/native' % hdfs_home)
lpath.append('%s/jre/lib/amd64/server' % java_home)
env = os.environ.copy()
env['CLASSPATH'] = '${CLASSPATH}:' + (':'.join(cpath))
# setup hdfs options
if 'rabit_hdfs_opts' in env:
env['LIBHDFS_OPTS'] = env['rabit_hdfs_opts']
elif 'LIBHDFS_OPTS' not in env:
env['LIBHDFS_OPTS'] = '-Xmx128m'
env['LD_LIBRARY_PATH'] = '${LD_LIBRARY_PATH}:' + (':'.join(lpath))
ret = subprocess.call(args = sys.argv[1:], env = env)
sys.exit(ret)


@ -1,5 +1,6 @@
package org.apache.hadoop.yarn.rabit;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
@ -7,6 +8,7 @@ import java.util.Map;
import java.util.Queue;
import java.util.Collection;
import java.util.Collections;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
@ -34,6 +36,7 @@ import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.security.UserGroupInformation;
/**
* application master for allocating resources of rabit client
@ -61,6 +64,8 @@ public class ApplicationMaster {
// command to launch
private String command = "";
// username
private String userName = "";
// application tracker hostname
private String appHostName = "";
// tracker URL to do
@ -128,6 +133,8 @@ public class ApplicationMaster {
*/
private void initArgs(String args[]) throws IOException {
LOG.info("Invoke initArgs");
// get user name
userName = UserGroupInformation.getCurrentUser().getShortUserName();
// cached maps
Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
for (int i = 0; i < args.length; ++i) {
@ -156,7 +163,8 @@ public class ApplicationMaster {
numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false, maxNumAttempt);
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false,
maxNumAttempt);
}
/**
@ -208,10 +216,10 @@ public class ApplicationMaster {
assert (killedTasks.size() + finishedTasks.size() == numTasks);
success = finishedTasks.size() == numTasks;
LOG.info("Application completed. Stopping running containers");
nmClient.stop();
diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
+ ", finished=" + this.finishedTasks.size() + ", failed="
+ this.killedTasks.size() + "\n" + this.abortDiagnosis;
nmClient.stop();
LOG.info(diagnostics);
} catch (Exception e) {
diagnostics = e.toString();
@ -220,7 +228,8 @@ public class ApplicationMaster {
success ? FinalApplicationStatus.SUCCEEDED
: FinalApplicationStatus.FAILED, diagnostics,
appTrackerUrl);
if (!success) throw new Exception("Application not successful");
if (!success)
throw new Exception("Application not successful");
}
/**
@ -265,14 +274,19 @@ public class ApplicationMaster {
task.containerRequest = null;
ContainerLaunchContext ctx = Records
.newRecord(ContainerLaunchContext.class);
String hadoop = "hadoop";
if (System.getenv("HADOOP_HOME") != null) {
hadoop = "${HADOOP_HOME}/bin/hadoop";
} else if (System.getenv("HADOOP_PREFIX") != null) {
hadoop = "${HADOOP_PREFIX}/bin/hadoop";
}
String cmd =
// use this to setup CLASSPATH correctly for libhdfs
"CLASSPATH=${CLASSPATH}:`${HADOOP_PREFIX}/bin/hadoop classpath --glob` "
+ this.command + " 1>"
this.command + " 1>"
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ "/stderr";
LOG.info(cmd);
ctx.setCommands(Collections.singletonList(cmd));
LOG.info(workerResources);
ctx.setLocalResources(this.workerResources);
@ -284,11 +298,39 @@ public class ApplicationMaster {
for (String c : conf.getStrings(
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
cpath.append(':');
cpath.append(c.trim());
String[] arrPath = c.split(":");
for (String ps : arrPath) {
if (ps.endsWith("*.jar") || ps.endsWith("*")) {
ps = ps.substring(0, ps.lastIndexOf('*'));
String prefix = ps.substring(0, ps.lastIndexOf('/'));
if (ps.startsWith("$")) {
String[] arr =ps.split("/", 2);
if (arr.length != 2) continue;
try {
ps = System.getenv(arr[0].substring(1)) + '/' + arr[1];
} catch (Exception e){
continue;
}
// already use hadoop command to get class path in worker, maybe a better solution in future
// env.put("CLASSPATH", cpath.toString());
}
File dir = new File(ps);
if (dir.isDirectory()) {
for (File f: dir.listFiles()) {
if (f.isFile() && f.getPath().endsWith(".jar")) {
cpath.append(":");
cpath.append(prefix + '/' + f.getName());
}
}
}
} else {
cpath.append(':');
cpath.append(ps.trim());
}
}
}
// already use hadoop command to get class path in worker, maybe a
// better solution in future
env.put("CLASSPATH", cpath.toString());
//LOG.info("CLASSPATH =" + cpath.toString());
// setup LD_LIBARY_PATH path for libhdfs
env.put("LD_LIBRARY_PATH",
"${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
@ -298,10 +340,13 @@ public class ApplicationMaster {
if (e.getKey().startsWith("rabit_")) {
env.put(e.getKey(), e.getValue());
}
if (e.getKey().equals("LIBHDFS_OPTS")) {
env.put(e.getKey(), e.getValue());
}
}
env.put("rabit_task_id", String.valueOf(task.taskId));
env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
// ctx.setUser(userName);
ctx.setEnvironment(env);
synchronized (this) {
assert (!this.runningTasks.containsKey(container.getId()));
@ -376,8 +421,17 @@ public class ApplicationMaster {
Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
for (ContainerId cid : failed) {
TaskRecord r = runningTasks.remove(cid);
if (r == null)
if (r == null) {
continue;
}
LOG.info("Task "
+ r.taskId
+ "failed on "
+ r.container.getId()
+ ". See LOG at : "
+ String.format("http://%s/node/containerlogs/%s/"
+ userName, r.container.getNodeHttpAddress(),
r.container.getId()));
r.attemptCounter += 1;
r.container = null;
tasks.add(r);
@ -411,22 +465,26 @@ public class ApplicationMaster {
finishedTasks.add(r);
runningTasks.remove(s.getContainerId());
} else {
switch (exstatus) {
case ContainerExitStatus.KILLED_EXCEEDED_PMEM:
try {
if (exstatus == ContainerExitStatus.class.getField(
"KILLED_EXCEEDED_PMEM").getInt(null)) {
this.abortJob("[Rabit] Task "
+ r.taskId
+ " killed because of exceeding allocated physical memory");
break;
case ContainerExitStatus.KILLED_EXCEEDED_VMEM:
continue;
}
if (exstatus == ContainerExitStatus.class.getField(
"KILLED_EXCEEDED_VMEM").getInt(null)) {
this.abortJob("[Rabit] Task "
+ r.taskId
+ " killed because of exceeding allocated virtual memory");
break;
default:
LOG.info("[Rabit] Task " + r.taskId
+ " exited with status " + exstatus);
failed.add(s.getContainerId());
continue;
}
} catch (Exception e) {
}
LOG.info("[Rabit] Task " + r.taskId + " exited with status "
+ exstatus + " Diagnostics:"+ s.getDiagnostics());
failed.add(s.getContainerId());
}
}
this.handleFailure(failed);


@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
@ -19,6 +20,7 @@ import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.QueueInfo;
import org.apache.hadoop.yarn.api.records.YarnApplicationState;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.hadoop.yarn.client.api.YarnClientApplication;
@ -43,14 +45,21 @@ public class Client {
private String appArgs = "";
// HDFS Path to store temporal result
private String tempdir = "/tmp";
// user name
private String userName = "";
// job name
private String jobName = "";
// queue
private String queue = "default";
/**
* constructor
* @throws IOException
*/
private Client() throws IOException {
conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/core-site.xml"));
conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/hdfs-site.xml"));
dfs = FileSystem.get(conf);
userName = UserGroupInformation.getCurrentUser().getShortUserName();
}
/**
@ -127,6 +136,9 @@ public class Client {
if (e.getKey().startsWith("rabit_")) {
env.put(e.getKey(), e.getValue());
}
if (e.getKey().equals("LIBHDFS_OPTS")) {
env.put(e.getKey(), e.getValue());
}
}
LOG.debug(env);
return env;
@ -152,6 +164,8 @@ public class Client {
this.jobName = args[++i];
} else if(args[i].equals("-tempdir")) {
this.tempdir = args[++i];
} else if(args[i].equals("-queue")) {
this.queue = args[++i];
} else {
sargs.append(" ");
sargs.append(args[i]);
@ -168,7 +182,6 @@ public class Client {
}
this.initArgs(args);
// Create yarnClient
YarnConfiguration conf = new YarnConfiguration();
YarnClient yarnClient = YarnClient.createYarnClient();
yarnClient.init(conf);
yarnClient.start();
@ -181,13 +194,14 @@ public class Client {
.newRecord(ContainerLaunchContext.class);
ApplicationSubmissionContext appContext = app
.getApplicationSubmissionContext();
// Submit application
ApplicationId appId = appContext.getApplicationId();
// setup cache-files and environment variables
amContainer.setLocalResources(this.setupCacheFiles(appId));
amContainer.setEnvironment(this.getEnvironment());
String cmd = "$JAVA_HOME/bin/java"
+ " -Xmx256M"
+ " -Xmx900M"
+ " org.apache.hadoop.yarn.rabit.ApplicationMaster"
+ this.cacheFileArg + ' ' + this.appArgs + " 1>"
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
@ -197,15 +211,15 @@ public class Client {
// Set up resource type requirements for ApplicationMaster
Resource capability = Records.newRecord(Resource.class);
capability.setMemory(256);
capability.setMemory(1024);
capability.setVirtualCores(1);
LOG.info("jobname=" + this.jobName);
LOG.info("jobname=" + this.jobName + ",username=" + this.userName);
appContext.setApplicationName(jobName + ":RABIT-YARN");
appContext.setAMContainerSpec(amContainer);
appContext.setResource(capability);
appContext.setQueue("default");
appContext.setQueue(queue);
//appContext.setUser(userName);
LOG.info("Submitting application " + appId);
yarnClient.submitApplication(appContext);
@ -224,6 +238,10 @@ public class Client {
if (!appReport.getFinalApplicationStatus().equals(
FinalApplicationStatus.SUCCEEDED)) {
System.err.println(appReport.getDiagnostics());
System.out.println("Available queues:");
for (QueueInfo q : yarnClient.getAllQueues()) {
System.out.println(q.getQueueName());
}
}
}