Squashed 'subtree/rabit/' changes from 091634b..59e63bc
59e63bc minor 6233050 ok 14477f9 add namenode 75a6d34 add libhdfs opts e3c76bf minmum fix 8b3c435 chg 2035799 test code 7751b2b add debug 7690313 ok bd346b4 ok faba1dc add testload 6f7783e add testload e5f0340 ok 3ed9ec8 chg e552ac4 ask for more ram in am b2505e3 only stop nm when sucess bc696c9 add queue info f3e867e add option queue 5dc843c refactor fileio cd9c81b quick fix 1e23af2 add virtual destructor to iseekstream f165ffb fix hdfs 8cc6508 allow demo to pass in env fad4d69 ok 0fd6197 fix more 7423837 fix more d25de54 add temporal solution, run_yarn_prog.py e5a9e31 final attempt ed3bee8 add command back 0774000 add hdfs to resource 9b66e7e fix hadoop 6812f14 ok 08e1c16 change hadoop prefix back to hadoop home d6b6828 Update build.sh 146e069 bugfix: logical boundary for ring buffer 19cb685 ok 4cf3c13 Merge branch 'master' of ssh://github.com/tqchen/rabit 20daddb add tracker c57dad8 add ringbased passing and batch schedule 295d8a1 update 994cb02 add sge 014c866 OK git-subtree-dir: subtree/rabit git-subtree-split: 59e63bc1354c9ff516d72d9a6468f6c431627202
This commit is contained in:
parent
13a319ca01
commit
75bf97b575
@ -19,6 +19,8 @@ namespace utils {
|
||||
/*! \brief interface of i/o stream that support seek */
|
||||
class ISeekStream: public IStream {
|
||||
public:
|
||||
// virtual destructor
|
||||
virtual ~ISeekStream(void) {}
|
||||
/*! \brief seek to certain position of the file */
|
||||
virtual void Seek(size_t pos) = 0;
|
||||
/*! \brief tell the position of the stream */
|
||||
|
||||
2
rabit-learn/.gitignore
vendored
Normal file
2
rabit-learn/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
config.mk
|
||||
*.log
|
||||
@ -38,6 +38,7 @@ class StreamBufferReader {
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief whether we are reaching the end of file */
|
||||
inline bool AtEnd(void) const {
|
||||
return read_len_ == 0;
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ class FileStream : public utils::ISeekStream {
|
||||
public:
|
||||
explicit FileStream(const char *fname, const char *mode)
|
||||
: use_stdio(false) {
|
||||
using namespace std;
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
if (!strcmp(fname, "stdin")) {
|
||||
use_stdio = true; fp = stdin;
|
||||
@ -51,7 +52,7 @@ class FileStream : public utils::ISeekStream {
|
||||
return std::ftell(fp);
|
||||
}
|
||||
virtual bool AtEnd(void) const {
|
||||
return feof(fp) != 0;
|
||||
return std::feof(fp) != 0;
|
||||
}
|
||||
inline void Close(void) {
|
||||
if (fp != NULL && !use_stdio) {
|
||||
@ -60,45 +61,50 @@ class FileStream : public utils::ISeekStream {
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
std::FILE *fp;
|
||||
bool use_stdio;
|
||||
};
|
||||
|
||||
/*! \brief line split from normal file system */
|
||||
class FileSplit : public LineSplitBase {
|
||||
class FileProvider : public LineSplitter::IFileProvider {
|
||||
public:
|
||||
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
|
||||
LineSplitBase::SplitNames(&fnames_, uri, "#");
|
||||
explicit FileProvider(const char *uri) {
|
||||
LineSplitter::SplitNames(&fnames_, uri, "#");
|
||||
std::vector<size_t> fsize;
|
||||
for (size_t i = 0; i < fnames_.size(); ++i) {
|
||||
if (!strncmp(fnames_[i].c_str(), "file://", 7)) {
|
||||
if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
|
||||
std::string tmp = fnames_[i].c_str() + 7;
|
||||
fnames_[i] = tmp;
|
||||
}
|
||||
fsize.push_back(GetFileSize(fnames_[i].c_str()));
|
||||
size_t fz = GetFileSize(fnames_[i].c_str());
|
||||
if (fz != 0) {
|
||||
fsize_.push_back(fz);
|
||||
}
|
||||
}
|
||||
LineSplitBase::Init(fsize, rank, nsplit);
|
||||
}
|
||||
virtual ~FileSplit(void) {}
|
||||
|
||||
protected:
|
||||
virtual utils::ISeekStream *GetFile(size_t file_index) {
|
||||
// destrucor
|
||||
virtual ~FileProvider(void) {}
|
||||
virtual utils::ISeekStream *Open(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new FileStream(fnames_[file_index].c_str(), "rb");
|
||||
}
|
||||
virtual const std::vector<size_t> &FileSize(void) const {
|
||||
return fsize_;
|
||||
}
|
||||
private:
|
||||
// file sizes
|
||||
std::vector<size_t> fsize_;
|
||||
// file names
|
||||
std::vector<std::string> fnames_;
|
||||
// get file size
|
||||
inline static size_t GetFileSize(const char *fname) {
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
std::FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
// NOTE: fseek may not be good, but serves as ok solution
|
||||
fseek(fp, 0, SEEK_END);
|
||||
size_t fsize = static_cast<size_t>(ftell(fp));
|
||||
fclose(fp);
|
||||
std::fseek(fp, 0, SEEK_END);
|
||||
size_t fsize = static_cast<size_t>(std::ftell(fp));
|
||||
std::fclose(fp);
|
||||
return fsize;
|
||||
}
|
||||
|
||||
private:
|
||||
// file names
|
||||
std::vector<std::string> fnames_;
|
||||
};
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <hdfs.h>
|
||||
#include <errno.h>
|
||||
@ -15,11 +16,15 @@
|
||||
/*! \brief io interface */
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
class HDFSStream : public utils::ISeekStream {
|
||||
class HDFSStream : public ISeekStream {
|
||||
public:
|
||||
HDFSStream(hdfsFS fs, const char *fname, const char *mode)
|
||||
: fs_(fs), at_end_(false) {
|
||||
int flag;
|
||||
HDFSStream(hdfsFS fs,
|
||||
const char *fname,
|
||||
const char *mode,
|
||||
bool disconnect_when_done)
|
||||
: fs_(fs), at_end_(false),
|
||||
disconnect_when_done_(disconnect_when_done) {
|
||||
int flag = 0;
|
||||
if (!strcmp(mode, "r")) {
|
||||
flag = O_RDONLY;
|
||||
} else if (!strcmp(mode, "w")) {
|
||||
@ -35,6 +40,9 @@ class HDFSStream : public utils::ISeekStream {
|
||||
}
|
||||
virtual ~HDFSStream(void) {
|
||||
this->Close();
|
||||
if (disconnect_when_done_) {
|
||||
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
|
||||
}
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
tSize nread = hdfsRead(fs_, fp_, ptr, size);
|
||||
@ -86,52 +94,69 @@ class HDFSStream : public utils::ISeekStream {
|
||||
}
|
||||
}
|
||||
|
||||
inline static std::string GetNameNode(void) {
|
||||
const char *nn = getenv("rabit_hdfs_namenode");
|
||||
if (nn == NULL) {
|
||||
return std::string("default");
|
||||
} else {
|
||||
return std::string(nn);
|
||||
}
|
||||
}
|
||||
private:
|
||||
hdfsFS fs_;
|
||||
hdfsFile fp_;
|
||||
bool at_end_;
|
||||
bool disconnect_when_done_;
|
||||
};
|
||||
|
||||
/*! \brief line split from normal file system */
|
||||
class HDFSSplit : public LineSplitBase {
|
||||
class HDFSProvider : public LineSplitter::IFileProvider {
|
||||
public:
|
||||
explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) {
|
||||
fs_ = hdfsConnect("default", 0);
|
||||
explicit HDFSProvider(const char *uri) {
|
||||
fs_ = hdfsConnect(HDFSStream::GetNameNode().c_str(), 0);
|
||||
utils::Check(fs_ != NULL, "error when connecting to default HDFS");
|
||||
std::vector<std::string> paths;
|
||||
LineSplitBase::SplitNames(&paths, uri, "#");
|
||||
LineSplitter::SplitNames(&paths, uri, "#");
|
||||
// get the files
|
||||
std::vector<size_t> fsize;
|
||||
for (size_t i = 0; i < paths.size(); ++i) {
|
||||
hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
|
||||
utils::Check(info != NULL, "path %s do not exist", paths[i].c_str());
|
||||
if (info->mKind == 'D') {
|
||||
int nentry;
|
||||
hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
|
||||
utils::Check(files != NULL, "error when ListDirectory %s", info->mName);
|
||||
for (int i = 0; i < nentry; ++i) {
|
||||
if (files[i].mKind == 'F') {
|
||||
fsize.push_back(files[i].mSize);
|
||||
if (files[i].mKind == 'F' && files[i].mSize != 0) {
|
||||
fsize_.push_back(files[i].mSize);
|
||||
fnames_.push_back(std::string(files[i].mName));
|
||||
}
|
||||
}
|
||||
hdfsFreeFileInfo(files, nentry);
|
||||
} else {
|
||||
fsize.push_back(info->mSize);
|
||||
fnames_.push_back(std::string(info->mName));
|
||||
if (info->mSize != 0) {
|
||||
fsize_.push_back(info->mSize);
|
||||
fnames_.push_back(std::string(info->mName));
|
||||
}
|
||||
}
|
||||
hdfsFreeFileInfo(info, 1);
|
||||
}
|
||||
LineSplitBase::Init(fsize, rank, nsplit);
|
||||
}
|
||||
virtual ~HDFSSplit(void) {}
|
||||
|
||||
protected:
|
||||
virtual utils::ISeekStream *GetFile(size_t file_index) {
|
||||
virtual ~HDFSProvider(void) {
|
||||
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
|
||||
}
|
||||
virtual const std::vector<size_t> &FileSize(void) const {
|
||||
return fsize_;
|
||||
}
|
||||
virtual ISeekStream *Open(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
|
||||
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
// hdfs handle
|
||||
hdfsFS fs_;
|
||||
// file sizes
|
||||
std::vector<size_t> fsize_;
|
||||
// file names
|
||||
std::vector<std::string> fnames_;
|
||||
};
|
||||
|
||||
@ -25,20 +25,21 @@ namespace io {
|
||||
inline InputSplit *CreateInputSplit(const char *uri,
|
||||
unsigned part,
|
||||
unsigned nsplit) {
|
||||
using namespace std;
|
||||
if (!strcmp(uri, "stdin")) {
|
||||
return new SingleFileSplit(uri);
|
||||
}
|
||||
if (!strncmp(uri, "file://", 7)) {
|
||||
return new FileSplit(uri, part, nsplit);
|
||||
return new LineSplitter(new FileProvider(uri), part, nsplit);
|
||||
}
|
||||
if (!strncmp(uri, "hdfs://", 7)) {
|
||||
#if RABIT_USE_HDFS
|
||||
return new HDFSSplit(uri, part, nsplit);
|
||||
return new LineSplitter(new HDFSProvider(uri), part, nsplit);
|
||||
#else
|
||||
utils::Error("Please compile with RABIT_USE_HDFS=1");
|
||||
#endif
|
||||
}
|
||||
return new FileSplit(uri, part, nsplit);
|
||||
return new LineSplitter(new FileProvider(uri), part, nsplit);
|
||||
}
|
||||
/*!
|
||||
* \brief create an stream, the stream must be able to close
|
||||
@ -48,12 +49,14 @@ inline InputSplit *CreateInputSplit(const char *uri,
|
||||
* \param mode can be 'w' or 'r' for read or write
|
||||
*/
|
||||
inline IStream *CreateStream(const char *uri, const char *mode) {
|
||||
using namespace std;
|
||||
if (!strncmp(uri, "file://", 7)) {
|
||||
return new FileStream(uri + 7, mode);
|
||||
}
|
||||
if (!strncmp(uri, "hdfs://", 7)) {
|
||||
#if RABIT_USE_HDFS
|
||||
return new HDFSStream(hdfsConnect("default", 0), uri, mode);
|
||||
return new HDFSStream(hdfsConnect(HDFSStream::GetNameNode().c_str(), 0),
|
||||
uri, mode, true);
|
||||
#else
|
||||
utils::Error("Please compile with RABIT_USE_HDFS=1");
|
||||
#endif
|
||||
|
||||
@ -19,6 +19,7 @@ namespace rabit {
|
||||
* \brief namespace to handle input split and filesystem interfacing
|
||||
*/
|
||||
namespace io {
|
||||
/*! \brief reused ISeekStream's definition */
|
||||
typedef utils::ISeekStream ISeekStream;
|
||||
/*!
|
||||
* \brief user facing input split helper,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
#ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||
#define RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||
/*!
|
||||
* \file line_split-inl.h
|
||||
* \std::FILE line_split-inl.h
|
||||
* \brief base implementation of line-spliter
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@ -15,11 +15,42 @@
|
||||
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
class LineSplitBase : public InputSplit {
|
||||
|
||||
/*! \brief class that split the files by line */
|
||||
class LineSplitter : public InputSplit {
|
||||
public:
|
||||
virtual ~LineSplitBase() {
|
||||
if (fs_ != NULL) delete fs_;
|
||||
class IFileProvider {
|
||||
public:
|
||||
/*!
|
||||
* \brief get the seek stream of given file_index
|
||||
* \return the corresponding seek stream at head of the stream
|
||||
* the seek stream's resource can be freed by calling delete
|
||||
*/
|
||||
virtual ISeekStream *Open(size_t file_index) = 0;
|
||||
/*!
|
||||
* \return const reference to size of each files
|
||||
*/
|
||||
virtual const std::vector<size_t> &FileSize(void) const = 0;
|
||||
// virtual destructor
|
||||
virtual ~IFileProvider() {}
|
||||
};
|
||||
// constructor
|
||||
explicit LineSplitter(IFileProvider *provider,
|
||||
unsigned rank,
|
||||
unsigned nsplit)
|
||||
: provider_(provider), fs_(NULL),
|
||||
reader_(kBufferSize) {
|
||||
this->Init(provider_->FileSize(), rank, nsplit);
|
||||
}
|
||||
// destructor
|
||||
virtual ~LineSplitter() {
|
||||
if (fs_ != NULL) {
|
||||
delete fs_; fs_ = NULL;
|
||||
}
|
||||
// delete provider after destructing the streams
|
||||
delete provider_;
|
||||
}
|
||||
// get next line
|
||||
virtual bool NextLine(std::string *out_data) {
|
||||
if (file_ptr_ >= file_ptr_end_ &&
|
||||
offset_curr_ >= offset_end_) return false;
|
||||
@ -29,15 +60,15 @@ class LineSplitBase : public InputSplit {
|
||||
if (reader_.AtEnd()) {
|
||||
if (out_data->length() != 0) return true;
|
||||
file_ptr_ += 1;
|
||||
if (offset_curr_ >= offset_end_) return false;
|
||||
if (offset_curr_ != file_offset_[file_ptr_]) {
|
||||
utils::Error("warning:file size not calculated correctly\n");
|
||||
utils::Error("warning: FILE size not calculated correctly\n");
|
||||
offset_curr_ = file_offset_[file_ptr_];
|
||||
}
|
||||
if (offset_curr_ >= offset_end_) return false;
|
||||
utils::Assert(file_ptr_ + 1 < file_offset_.size(),
|
||||
"boundary check");
|
||||
delete fs_;
|
||||
fs_ = this->GetFile(file_ptr_);
|
||||
fs_ = provider_->Open(file_ptr_);
|
||||
reader_.set_stream(fs_);
|
||||
} else {
|
||||
++offset_curr_;
|
||||
@ -51,12 +82,24 @@ class LineSplitBase : public InputSplit {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// constructor
|
||||
LineSplitBase(void)
|
||||
: fs_(NULL), reader_(kBufferSize) {
|
||||
/*!
|
||||
* \brief split names given
|
||||
* \param out_fname output std::FILE names
|
||||
* \param uri_ the iput uri std::FILE
|
||||
* \param dlm deliminetr
|
||||
*/
|
||||
inline static void SplitNames(std::vector<std::string> *out_fname,
|
||||
const char *uri_,
|
||||
const char *dlm) {
|
||||
std::string uri = uri_;
|
||||
char *p = std::strtok(BeginPtr(uri), dlm);
|
||||
while (p != NULL) {
|
||||
out_fname->push_back(std::string(p));
|
||||
p = std::strtok(NULL, dlm);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/*!
|
||||
* \brief initialize the line spliter,
|
||||
* \param file_size, size of each files
|
||||
@ -82,7 +125,7 @@ class LineSplitBase : public InputSplit {
|
||||
file_ptr_end_ = std::upper_bound(file_offset_.begin(),
|
||||
file_offset_.end(),
|
||||
offset_end_) - file_offset_.begin() - 1;
|
||||
fs_ = GetFile(file_ptr_);
|
||||
fs_ = provider_->Open(file_ptr_);
|
||||
reader_.set_stream(fs_);
|
||||
// try to set the starting position correctly
|
||||
if (file_offset_[file_ptr_] != offset_begin_) {
|
||||
@ -94,28 +137,10 @@ class LineSplitBase : public InputSplit {
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief get the seek stream of given file_index
|
||||
* \return the corresponding seek stream at head of file
|
||||
*/
|
||||
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
|
||||
/*!
|
||||
* \brief split names given
|
||||
* \param out_fname output file names
|
||||
* \param uri_ the iput uri file
|
||||
* \param dlm deliminetr
|
||||
*/
|
||||
inline static void SplitNames(std::vector<std::string> *out_fname,
|
||||
const char *uri_,
|
||||
const char *dlm) {
|
||||
std::string uri = uri_;
|
||||
char *p = strtok(BeginPtr(uri), dlm);
|
||||
while (p != NULL) {
|
||||
out_fname->push_back(std::string(p));
|
||||
p = strtok(NULL, dlm);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief FileProvider */
|
||||
IFileProvider *provider_;
|
||||
/*! \brief current input stream */
|
||||
utils::ISeekStream *fs_;
|
||||
/*! \brief file pointer of which file to read on */
|
||||
@ -136,11 +161,11 @@ class LineSplitBase : public InputSplit {
|
||||
const static size_t kBufferSize = 256;
|
||||
};
|
||||
|
||||
/*! \brief line split from single file */
|
||||
/*! \brief line split from single std::FILE */
|
||||
class SingleFileSplit : public InputSplit {
|
||||
public:
|
||||
explicit SingleFileSplit(const char *fname) {
|
||||
if (!strcmp(fname, "stdin")) {
|
||||
if (!std::strcmp(fname, "stdin")) {
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
use_stdin_ = true; fp_ = stdin;
|
||||
#endif
|
||||
@ -151,13 +176,13 @@ class SingleFileSplit : public InputSplit {
|
||||
end_of_file_ = false;
|
||||
}
|
||||
virtual ~SingleFileSplit(void) {
|
||||
if (!use_stdin_) fclose(fp_);
|
||||
if (!use_stdin_) std::fclose(fp_);
|
||||
}
|
||||
virtual bool NextLine(std::string *out_data) {
|
||||
if (end_of_file_) return false;
|
||||
out_data->clear();
|
||||
while (true) {
|
||||
char c = fgetc(fp_);
|
||||
char c = std::fgetc(fp_);
|
||||
if (c == EOF) {
|
||||
end_of_file_ = true;
|
||||
}
|
||||
@ -172,7 +197,7 @@ class SingleFileSplit : public InputSplit {
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp_;
|
||||
std::FILE *fp_;
|
||||
bool use_stdin_;
|
||||
bool end_of_file_;
|
||||
};
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
# specify tensor path
|
||||
ifneq ("$(wildcard ../config.mk)","")
|
||||
config = ../config.mk
|
||||
else
|
||||
config = ../make/config.mk
|
||||
endif
|
||||
include $(config)
|
||||
|
||||
BIN = linear.rabit
|
||||
MOCKBIN= linear.mock
|
||||
MPIBIN =
|
||||
@ -6,10 +12,10 @@ MPIBIN =
|
||||
OBJ = linear.o
|
||||
|
||||
# common build script for programs
|
||||
include ../make/config.mk
|
||||
include ../make/common.mk
|
||||
CFLAGS+=-fopenmp
|
||||
linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
|
||||
# dependenies here
|
||||
linear.rabit: linear.o lib
|
||||
linear.mock: linear.o lib
|
||||
|
||||
|
||||
@ -206,21 +206,22 @@ int main(int argc, char *argv[]) {
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
rabit::linear::LinearObjFunction linear;
|
||||
rabit::linear::LinearObjFunction *linear = new rabit::linear::LinearObjFunction();
|
||||
if (!strcmp(argv[1], "stdin")) {
|
||||
linear.LoadData(argv[1]);
|
||||
linear->LoadData(argv[1]);
|
||||
rabit::Init(argc, argv);
|
||||
} else {
|
||||
rabit::Init(argc, argv);
|
||||
linear.LoadData(argv[1]);
|
||||
linear->LoadData(argv[1]);
|
||||
}
|
||||
for (int i = 2; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
linear.SetParam(name, val);
|
||||
linear->SetParam(name, val);
|
||||
}
|
||||
}
|
||||
linear.Run();
|
||||
linear->Run();
|
||||
delete linear;
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -26,10 +26,11 @@ struct LinearModel {
|
||||
int reserved[16];
|
||||
// constructor
|
||||
ModelParam(void) {
|
||||
memset(this, 0, sizeof(ModelParam));
|
||||
base_score = 0.5f;
|
||||
num_feature = 0;
|
||||
loss_type = 1;
|
||||
std::memset(reserved, 0, sizeof(reserved));
|
||||
num_feature = 0;
|
||||
}
|
||||
// initialize base score
|
||||
inline void InitBaseScore(void) {
|
||||
@ -119,7 +120,7 @@ struct LinearModel {
|
||||
}
|
||||
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||
}
|
||||
inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
|
||||
inline void Save(rabit::IStream &fo, const float *wptr = NULL) {
|
||||
fo.Write(¶m, sizeof(param));
|
||||
if (wptr == NULL) wptr = weight;
|
||||
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
|
||||
|
||||
@ -6,12 +6,13 @@ then
|
||||
fi
|
||||
|
||||
# put the local training file to HDFS
|
||||
hadoop fs -rm -r -f $2/data
|
||||
hadoop fs -rm -r -f $2/mushroom.linear.model
|
||||
|
||||
hadoop fs -mkdir $2/data
|
||||
hadoop fs -put ../data/agaricus.txt.train $2/data
|
||||
|
||||
# submit to hadoop
|
||||
../../tracker/rabit_yarn.py -n $1 --vcores 1 linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
|
||||
../../tracker/rabit_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
|
||||
|
||||
# get the final model file
|
||||
hadoop fs -get $2/mushroom.linear.model ./linear.model
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
#
|
||||
# - copy this file to the root of rabit-learn folder
|
||||
# - modify the configuration you want
|
||||
# - type make or make -j n for parallel build
|
||||
# - type make or make -j n on each of the folder
|
||||
#----------------------------------------------------
|
||||
|
||||
# choice of compiler
|
||||
|
||||
@ -145,8 +145,9 @@ class LBFGSSolver {
|
||||
|
||||
if (silent == 0 && rabit::GetRank() == 0) {
|
||||
rabit::TrackerPrintf
|
||||
("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu\n",
|
||||
gstate.num_dim, gstate.init_objval, gstate.size_memory);
|
||||
("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu, RAM-approx=%lu\n",
|
||||
gstate.num_dim, gstate.init_objval, gstate.size_memory,
|
||||
gstate.MemCost() + hist.MemCost());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -176,7 +177,7 @@ class LBFGSSolver {
|
||||
// swap new weight
|
||||
std::swap(g.weight, g.grad);
|
||||
// check stop condition
|
||||
if (gstate.num_iteration > min_lbfgs_iter) {
|
||||
if (gstate.num_iteration > static_cast<size_t>(min_lbfgs_iter)) {
|
||||
if (g.old_objval - g.new_objval < lbfgs_stop_tol * g.init_objval) {
|
||||
return true;
|
||||
}
|
||||
@ -195,7 +196,7 @@ class LBFGSSolver {
|
||||
/*! \brief run optimization */
|
||||
virtual void Run(void) {
|
||||
this->Init();
|
||||
while (gstate.num_iteration < max_lbfgs_iter) {
|
||||
while (gstate.num_iteration < static_cast<size_t>(max_lbfgs_iter)) {
|
||||
if (this->UpdateOneIter()) break;
|
||||
}
|
||||
if (silent == 0 && rabit::GetRank() == 0) {
|
||||
@ -225,7 +226,7 @@ class LBFGSSolver {
|
||||
const size_t num_dim = gstate.num_dim;
|
||||
const DType *gsub = grad + range_begin_;
|
||||
const size_t nsub = range_end_ - range_begin_;
|
||||
double vdot;
|
||||
double vdot = 0.0;
|
||||
if (n != 0) {
|
||||
// hist[m + n - 1] stores old gradient
|
||||
Minus(hist[m + n - 1], gsub, hist[m + n - 1], nsub);
|
||||
@ -241,15 +242,19 @@ class LBFGSSolver {
|
||||
idxset.push_back(std::make_pair(m + j, 2 * m));
|
||||
idxset.push_back(std::make_pair(m + j, m + n - 1));
|
||||
}
|
||||
|
||||
// calculate dot products
|
||||
std::vector<double> tmp(idxset.size());
|
||||
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||
tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second);
|
||||
}
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size());
|
||||
|
||||
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||
gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
|
||||
}
|
||||
|
||||
// BFGS steps, use vector-free update
|
||||
// parameterize vector using basis in hist
|
||||
std::vector<double> alpha(n);
|
||||
@ -263,7 +268,7 @@ class LBFGSSolver {
|
||||
}
|
||||
alpha[j] = vsum / gstate.DotBuf(j, m + j);
|
||||
delta[m + j] = delta[m + j] - alpha[j];
|
||||
}
|
||||
}
|
||||
// scale
|
||||
double scale = gstate.DotBuf(n - 1, m + n - 1) /
|
||||
gstate.DotBuf(m + n - 1, m + n - 1);
|
||||
@ -279,6 +284,7 @@ class LBFGSSolver {
|
||||
double beta = vsum / gstate.DotBuf(j, m + j);
|
||||
delta[j] = delta[j] + (alpha[j] - beta);
|
||||
}
|
||||
|
||||
// set all to zero
|
||||
std::fill(dir, dir + num_dim, 0.0f);
|
||||
DType *dirsub = dir + range_begin_;
|
||||
@ -291,10 +297,11 @@ class LBFGSSolver {
|
||||
}
|
||||
FixDirL1Sign(dirsub, hist[2 * m], nsub);
|
||||
vdot = -Dot(dirsub, hist[2 * m], nsub);
|
||||
|
||||
// allreduce to get full direction
|
||||
rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
|
||||
rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
|
||||
} else {
|
||||
} else {
|
||||
SetL1Dir(dir, grad, weight, num_dim);
|
||||
vdot = -Dot(dir, dir, num_dim);
|
||||
}
|
||||
@ -482,6 +489,7 @@ class LBFGSSolver {
|
||||
num_iteration = 0;
|
||||
num_dim = 0;
|
||||
old_objval = 0.0;
|
||||
offset_ = 0;
|
||||
}
|
||||
~GlobalState(void) {
|
||||
if (grad != NULL) {
|
||||
@ -496,6 +504,10 @@ class LBFGSSolver {
|
||||
data.resize(n * n, 0.0);
|
||||
this->AllocSpace();
|
||||
}
|
||||
// memory cost
|
||||
inline size_t MemCost(void) const {
|
||||
return sizeof(DType) * 3 * num_dim;
|
||||
}
|
||||
inline double &DotBuf(size_t i, size_t j) {
|
||||
if (i > j) std::swap(i, j);
|
||||
return data[MapIndex(i, offset_, size_memory) * (size_memory * 2 + 1) +
|
||||
@ -565,6 +577,10 @@ class LBFGSSolver {
|
||||
size_t n = size_memory * 2 + 1;
|
||||
dptr_ = new DType[n * stride_];
|
||||
}
|
||||
// memory cost
|
||||
inline size_t MemCost(void) const {
|
||||
return sizeof(DType) * (size_memory_ * 2 + 1) * stride_;
|
||||
}
|
||||
// fetch element from rolling array
|
||||
inline const DType *operator[](size_t i) const {
|
||||
return dptr_ + MapIndex(i, offset_, size_memory_) * stride_;
|
||||
|
||||
@ -77,11 +77,15 @@ struct SparseMat {
|
||||
feat_dim += 1;
|
||||
utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
|
||||
"feature dimension exceed limit of index_t"\
|
||||
"consider change the index_t to unsigned long");
|
||||
"consider change the index_t to unsigned long");
|
||||
}
|
||||
inline size_t NumRow(void) const {
|
||||
return row_ptr.size() - 1;
|
||||
}
|
||||
// memory cost
|
||||
inline size_t MemCost(void) const {
|
||||
return data.size() * sizeof(Entry);
|
||||
}
|
||||
// maximum feature dimension
|
||||
size_t feat_dim;
|
||||
std::vector<size_t> row_ptr;
|
||||
|
||||
@ -26,6 +26,9 @@ AllreduceBase::AllreduceBase(void) {
|
||||
world_size = -1;
|
||||
hadoop_mode = 0;
|
||||
version_number = 0;
|
||||
// 32 K items
|
||||
reduce_ring_mincount = 32 << 10;
|
||||
// tracker URL
|
||||
task_id = "NULL";
|
||||
err_link = NULL;
|
||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||
@ -33,7 +36,8 @@ AllreduceBase::AllreduceBase(void) {
|
||||
env_vars.push_back("rabit_task_id");
|
||||
env_vars.push_back("rabit_num_trial");
|
||||
env_vars.push_back("rabit_reduce_buffer");
|
||||
env_vars.push_back("rabit_tracker_uri");
|
||||
env_vars.push_back("rabit_reduce_ring_mincount");
|
||||
env_vars.push_back("rabit_tracker_uri");
|
||||
env_vars.push_back("rabit_tracker_port");
|
||||
}
|
||||
|
||||
@ -116,6 +120,27 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
tracker.SendStr(msg);
|
||||
tracker.Close();
|
||||
}
|
||||
// util to parse data with unit suffix
|
||||
inline size_t ParseUnit(const char *name, const char *val) {
|
||||
char unit;
|
||||
uint64_t amount;
|
||||
int n = sscanf(val, "%lu%c", &amount, &unit);
|
||||
if (n == 2) {
|
||||
switch (unit) {
|
||||
case 'B': return amount;
|
||||
case 'K': return amount << 10UL;
|
||||
case 'M': return amount << 20UL;
|
||||
case 'G': return amount << 30UL;
|
||||
default: utils::Error("invalid format for %s", name); return 0;
|
||||
}
|
||||
} else if (n == 1) {
|
||||
return amount;
|
||||
} else {
|
||||
utils::Error("invalid format for %s," \
|
||||
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
@ -127,21 +152,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
||||
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
||||
reduce_ring_mincount = ParseUnit(name, val);
|
||||
}
|
||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||
char unit;
|
||||
uint64_t amount;
|
||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||
switch (unit) {
|
||||
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
|
||||
case 'K': reduce_buffer_size = amount << 7UL; break;
|
||||
case 'M': reduce_buffer_size = amount << 17UL; break;
|
||||
case 'G': reduce_buffer_size = amount << 27UL; break;
|
||||
default: utils::Error("invalid format for reduce buffer");
|
||||
}
|
||||
} else {
|
||||
utils::Error("invalid format for reduce_buffer,"\
|
||||
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||
}
|
||||
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
@ -341,6 +356,28 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
if (count > reduce_ring_mincount) {
|
||||
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
|
||||
} else {
|
||||
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||
* this function implements tree-shape reduction
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || count == 0) return kSuccess;
|
||||
// total size of message
|
||||
@ -411,7 +448,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
|
||||
ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
|
||||
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
@ -599,5 +636,217 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
// read from next link and send to prev one
|
||||
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||
// need to reply on special rank structure
|
||||
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||
rank == (prev.rank + 1) % world_size,
|
||||
"need to assume rank structure");
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
const size_t stop_read = total_size + slice_begin;
|
||||
const size_t stop_write = total_size + slice_begin - size_prev_slice;
|
||||
size_t write_ptr = slice_begin;
|
||||
size_t read_ptr = slice_end;
|
||||
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::SelectHelper selecter;
|
||||
if (read_ptr != stop_read) {
|
||||
selecter.WatchRead(next.sock);
|
||||
finished = false;
|
||||
}
|
||||
if (write_ptr != stop_write) {
|
||||
if (write_ptr < read_ptr) {
|
||||
selecter.WatchWrite(prev.sock);
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
selecter.Select();
|
||||
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
|
||||
size_t size = stop_read - read_ptr;
|
||||
size_t start = read_ptr % total_size;
|
||||
if (start + size > total_size) {
|
||||
size = total_size - start;
|
||||
}
|
||||
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
|
||||
if (len != -1) {
|
||||
read_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return(errno);
|
||||
if (ret != kSuccess) return ReportError(&next, ret);
|
||||
}
|
||||
}
|
||||
if (write_ptr < read_ptr && write_ptr != stop_write) {
|
||||
size_t size = std::min(read_ptr, stop_write) - write_ptr;
|
||||
size_t start = write_ptr % total_size;
|
||||
if (start + size > total_size) {
|
||||
size = total_size - start;
|
||||
}
|
||||
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
|
||||
if (len != -1) {
|
||||
write_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return(errno);
|
||||
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
|
||||
* and will return the cause of failure
|
||||
*
|
||||
* Ring-based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryAllreduce
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
// read from next link and send to prev one
|
||||
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||
// need to reply on special rank structure
|
||||
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||
rank == (prev.rank + 1) % world_size,
|
||||
"need to assume rank structure");
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
size_t n = static_cast<size_t>(world_size);
|
||||
size_t step = (count + n - 1) / n;
|
||||
size_t r = static_cast<size_t>(next.rank);
|
||||
size_t write_ptr = std::min(r * step, count) * type_nbytes;
|
||||
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
|
||||
size_t reduce_ptr = read_ptr;
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
// position to stop reading
|
||||
const size_t stop_read = total_size + write_ptr;
|
||||
// position to stop writing
|
||||
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
|
||||
if (stop_write > stop_read) {
|
||||
stop_write -= total_size;
|
||||
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
|
||||
}
|
||||
// use ring buffer in next position
|
||||
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
|
||||
// set size_read to read pointer for ring buffer to work properly
|
||||
next.size_read = read_ptr;
|
||||
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::SelectHelper selecter;
|
||||
if (read_ptr != stop_read) {
|
||||
selecter.WatchRead(next.sock);
|
||||
finished = false;
|
||||
}
|
||||
if (write_ptr != stop_write) {
|
||||
if (write_ptr < reduce_ptr) {
|
||||
selecter.WatchWrite(prev.sock);
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
selecter.Select();
|
||||
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
|
||||
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&next, ret);
|
||||
}
|
||||
// sync the rate
|
||||
read_ptr = next.size_read;
|
||||
utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
|
||||
const size_t buffer_size = next.buffer_size;
|
||||
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
|
||||
while (reduce_ptr < max_reduce) {
|
||||
size_t bstart = reduce_ptr % buffer_size;
|
||||
size_t nread = std::min(buffer_size - bstart,
|
||||
max_reduce - reduce_ptr);
|
||||
size_t rstart = reduce_ptr % total_size;
|
||||
nread = std::min(nread, total_size - rstart);
|
||||
reducer(next.buffer_head + bstart,
|
||||
sendrecvbuf + rstart,
|
||||
static_cast<int>(nread / type_nbytes),
|
||||
MPI::Datatype(type_nbytes));
|
||||
reduce_ptr += nread;
|
||||
}
|
||||
}
|
||||
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
|
||||
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
|
||||
size_t start = write_ptr % total_size;
|
||||
if (start + size > total_size) {
|
||||
size = total_size - start;
|
||||
}
|
||||
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
|
||||
if (len != -1) {
|
||||
write_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return(errno);
|
||||
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
|
||||
if (ret != kSuccess) return ret;
|
||||
size_t n = static_cast<size_t>(world_size);
|
||||
size_t step = (count + n - 1) / n;
|
||||
size_t begin = std::min(rank * step, count) * type_nbytes;
|
||||
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
|
||||
// previous rank
|
||||
int prank = ring_prev->rank;
|
||||
// get rank of previous
|
||||
return TryAllgatherRing
|
||||
(sendrecvbuf_, type_nbytes * count,
|
||||
begin, end,
|
||||
(std::min((prank + 1) * step, count) -
|
||||
std::min(prank * step, count)) * type_nbytes);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -278,15 +278,19 @@ class AllreduceBase : public IEngine {
|
||||
* \brief read data into ring-buffer, with care not to existing useful override data
|
||||
* position after protect_start
|
||||
* \param protect_start all data start from protect_start is still needed in buffer
|
||||
* read shall not override this
|
||||
* read shall not override this
|
||||
* \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
|
||||
* \return the type of reading
|
||||
*/
|
||||
inline ReturnType ReadToRingBuffer(size_t protect_start) {
|
||||
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
|
||||
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
|
||||
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
|
||||
size_t ngap = size_read - protect_start;
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||
size_t nmax = max_size_read - size_read;
|
||||
nmax = std::min(nmax, buffer_size - ngap);
|
||||
nmax = std::min(nmax, buffer_size - offset);
|
||||
if (nmax == 0) return kSuccess;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
@ -380,13 +384,79 @@ class AllreduceBase : public IEngine {
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||
* this function implements tree-shape reduction
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllreduceTree(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin, size_t slice_end,
|
||||
size_t size_prev_slice);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
|
||||
*
|
||||
* after the function, node k get k-th segment of the reduction result
|
||||
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
|
||||
* where step = ceil(count / world_size)
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryAllreduce
|
||||
*/
|
||||
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* use a ring based algorithm, reduce-scatter + allgather
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllreduceRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief function used to report error when a link goes wrong
|
||||
* \param link the pointer to the link who causes the error
|
||||
@ -432,6 +502,10 @@ class AllreduceBase : public IEngine {
|
||||
int slave_port, nport_trial;
|
||||
// reduce buffer size
|
||||
size_t reduce_buffer_size;
|
||||
// reduction method
|
||||
int reduce_method;
|
||||
// mininum count of cells to use ring based method
|
||||
size_t reduce_ring_mincount;
|
||||
// current rank
|
||||
int rank;
|
||||
// world size
|
||||
|
||||
@ -81,18 +81,18 @@ class AllreduceMock : public AllreduceRobust {
|
||||
ComboSerializer com(global_model, local_model);
|
||||
AllreduceRobust::CheckPoint(&dum, &com);
|
||||
}
|
||||
tsum_allreduce = 0.0;
|
||||
time_checkpoint = utils::GetTime();
|
||||
double tcost = utils::GetTime() - tstart;
|
||||
if (report_stats != 0 && rank == 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
|
||||
<< "local_size=" << local_chkpt[local_chkpt_version].length()
|
||||
<< "check_tcost="<< tcost <<" sec,"
|
||||
<< "allreduce_tcost=" << tsum_allreduce << " sec,"
|
||||
<< "between_chpt=" << tbet_chkpt << "sec\n";
|
||||
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
|
||||
<< ",check_tcost="<< tcost <<" sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
|
||||
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
||||
this->TrackerPrint(ss.str());
|
||||
}
|
||||
tsum_allreduce = 0.0;
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
|
||||
@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
|
||||
}
|
||||
utils::Assert(min_write <= links[pid].size_read, "boundary check");
|
||||
ReturnType ret = links[pid].ReadToRingBuffer(min_write);
|
||||
ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[pid], ret);
|
||||
}
|
||||
|
||||
@ -287,7 +287,6 @@ class AllreduceRobust : public AllreduceBase {
|
||||
if (seqno_.size() == 0) return -1;
|
||||
return seqno_.back();
|
||||
}
|
||||
|
||||
private:
|
||||
// sequence number of each
|
||||
std::vector<int> seqno_;
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
export CC = gcc
|
||||
export CXX = g++
|
||||
export MPICXX = mpicxx
|
||||
export LDFLAGS= -pthread -lm -lrt -L../lib
|
||||
export LDFLAGS= -L../lib -pthread -lm -lrt
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11
|
||||
|
||||
# specify tensor path
|
||||
@ -29,7 +29,7 @@ local_recover: local_recover.o $(RABIT_OBJ)
|
||||
lazy_recover: lazy_recover.o $(RABIT_OBJ)
|
||||
|
||||
$(BIN) :
|
||||
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
|
||||
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
|
||||
|
||||
$(OBJ) :
|
||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||
|
||||
@ -23,4 +23,7 @@ lazy_recover_10_10k_die_hard:
|
||||
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
||||
|
||||
lazy_recover_10_10k_die_same:
|
||||
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
||||
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
||||
|
||||
ringallreduce_10_10k:
|
||||
../tracker/rabit_demo.py -v 1 -n 10 model_recover 100 rabit_reduce_ring_mincount=10
|
||||
|
||||
@ -9,4 +9,4 @@ the example guidelines are in the script themselfs
|
||||
* Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
|
||||
- It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py
|
||||
- However, it is higly recommended to use rabit_yarn.py because this will allocate resources more precisely and fits machine learning scenarios
|
||||
|
||||
* Sun Grid engine: [rabit_sge.py](rabit_sge.py)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#!/usr/bin/python
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This is the demo submission script of rabit, it is created to
|
||||
submit rabit jobs using hadoop streaming
|
||||
This is the demo submission script of rabit for submitting jobs in local machine
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
@ -43,7 +42,7 @@ def exec_cmd(cmd, taskid, worker_env):
|
||||
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
|
||||
cmd[0] = './' + cmd[0]
|
||||
cmd = ' '.join(cmd)
|
||||
env = {}
|
||||
env = os.environ.copy()
|
||||
for k, v in worker_env.items():
|
||||
env[k] = str(v)
|
||||
env['rabit_task_id'] = str(taskid)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#!/usr/bin/python
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Deprecated
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#!/usr/bin/python
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This is the demo submission script of rabit, it is created to
|
||||
submit rabit jobs using hadoop streaming
|
||||
Submission script to submit rabit jobs using MPI
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
69
tracker/rabit_sge.py
Executable file
69
tracker/rabit_sge.py
Executable file
@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Submit rabit jobs to Sun Grid Engine
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import rabit_tracker as tracker
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job using MPI')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
help = 'number of worker proccess to be launched')
|
||||
parser.add_argument('-q', '--queue', default='default', type=str,
|
||||
help = 'the queue we want to submit the job to')
|
||||
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
|
||||
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
|
||||
parser.add_argument('--vcores', default = 1, type=int,
|
||||
help = 'number of vcpores to request in each mapper, set it if each rabit job is multi-threaded')
|
||||
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
|
||||
parser.add_argument('--logdir', default='auto', help = 'customize the directory to place the logs')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.jobname == 'auto':
|
||||
args.jobname = ('rabit%d.' % args.nworker) + args.command[0].split('/')[-1];
|
||||
if args.logdir == 'auto':
|
||||
args.logdir = args.jobname + '.log'
|
||||
|
||||
if os.path.exists(args.logdir):
|
||||
if not os.path.isdir(args.logdir):
|
||||
raise RuntimeError('specified logdir %s is a file instead of directory' % args.logdir)
|
||||
else:
|
||||
os.mkdir(args.logdir)
|
||||
|
||||
runscript = '%s/runrabit.sh' % args.logdir
|
||||
fo = open(runscript, 'w')
|
||||
fo.write('\"$@\"\n')
|
||||
fo.close()
|
||||
#
|
||||
# submission script using MPI
|
||||
#
|
||||
def sge_submit(nslave, worker_args, worker_envs):
|
||||
"""
|
||||
customized submit script, that submit nslave jobs, each must contain args as parameter
|
||||
note this can be a lambda function containing additional parameters in input
|
||||
Parameters
|
||||
nslave number of slave process to start up
|
||||
args arguments to launch each job
|
||||
this usually includes the parameters of master_uri and parameters passed into submit
|
||||
"""
|
||||
env_arg = ','.join(['%s=\"%s\"' % (k, str(v)) for k, v in worker_envs.items()])
|
||||
cmd = 'qsub -cwd -t 1-%d -S /bin/bash' % nslave
|
||||
if args.queue != 'default':
|
||||
cmd += '-q %s' % args.queue
|
||||
cmd += ' -N %s ' % args.jobname
|
||||
cmd += ' -e %s -o %s' % (args.logdir, args.logdir)
|
||||
cmd += ' -pe orte %d' % (args.vcores)
|
||||
cmd += ' -v %s,PATH=${PATH}:.' % env_arg
|
||||
cmd += ' %s %s' % (runscript, ' '.join(args.command + worker_args))
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True)
|
||||
print 'Waiting for the jobs to get up...'
|
||||
|
||||
# call submit, with nslave, the commands to run each job and submit function
|
||||
tracker.submit(args.nworker, [], fun_submit = sge_submit, verbose = args.verbose)
|
||||
@ -13,6 +13,7 @@ import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import random
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
"""
|
||||
@ -188,6 +189,7 @@ class Tracker:
|
||||
vlst.reverse()
|
||||
rlst += vlst
|
||||
return rlst
|
||||
|
||||
def get_ring(self, tree_map, parent_map):
|
||||
"""
|
||||
get a ring connection used to recover local data
|
||||
@ -202,14 +204,44 @@ class Tracker:
|
||||
rnext = (r + 1) % nslave
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
|
||||
def get_link_map(self, nslave):
|
||||
"""
|
||||
get the link map, this is a bit hacky, call for better algorithm
|
||||
to place similar nodes together
|
||||
"""
|
||||
tree_map, parent_map = self.get_tree(nslave)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
rmap = {0 : 0}
|
||||
k = 0
|
||||
for i in range(nslave - 1):
|
||||
k = ring_map[k][1]
|
||||
rmap[k] = i + 1
|
||||
|
||||
ring_map_ = {}
|
||||
tree_map_ = {}
|
||||
parent_map_ ={}
|
||||
for k, v in ring_map.items():
|
||||
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
||||
for k, v in tree_map.items():
|
||||
tree_map_[rmap[k]] = [rmap[x] for x in v]
|
||||
for k, v in parent_map.items():
|
||||
if k != 0:
|
||||
parent_map_[rmap[k]] = rmap[v]
|
||||
else:
|
||||
parent_map_[rmap[k]] = -1
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
def handle_print(self,slave, msg):
|
||||
sys.stdout.write(msg)
|
||||
|
||||
def log_print(self, msg, level):
|
||||
if level == 1:
|
||||
if self.verbose:
|
||||
sys.stderr.write(msg + '\n')
|
||||
else:
|
||||
sys.stderr.write(msg + '\n')
|
||||
|
||||
def accept_slaves(self, nslave):
|
||||
# set of nodes that finishs the job
|
||||
shutdown = {}
|
||||
@ -241,31 +273,40 @@ class Tracker:
|
||||
assert s.cmd == 'start'
|
||||
if s.world_size > 0:
|
||||
nslave = s.world_size
|
||||
tree_map, parent_map = self.get_tree(nslave)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
tree_map, parent_map, ring_map = self.get_link_map(nslave)
|
||||
# set of nodes that is pending for getting up
|
||||
todo_nodes = range(nslave)
|
||||
random.shuffle(todo_nodes)
|
||||
else:
|
||||
assert s.world_size == -1 or s.world_size == nslave
|
||||
if s.cmd == 'recover':
|
||||
assert s.rank >= 0
|
||||
|
||||
rank = s.decide_rank(job_map)
|
||||
# batch assignment of ranks
|
||||
if rank == -1:
|
||||
assert len(todo_nodes) != 0
|
||||
rank = todo_nodes.pop(0)
|
||||
if s.jobid != 'NULL':
|
||||
job_map[s.jobid] = rank
|
||||
pending.append(s)
|
||||
if len(pending) == len(todo_nodes):
|
||||
pending.sort(key = lambda x : x.host)
|
||||
for s in pending:
|
||||
rank = todo_nodes.pop(0)
|
||||
if s.jobid != 'NULL':
|
||||
job_map[s.jobid] = rank
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
|
||||
if len(todo_nodes) == 0:
|
||||
self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.cmd != 'start':
|
||||
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
|
||||
self.start_time = time.time()
|
||||
else:
|
||||
self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
self.log_print('@tracker All nodes finishes job', 2)
|
||||
self.end_time = time.time()
|
||||
self.log_print('@tracker %s secs between node start and job finish' % str(self.end_time - self.start_time), 2)
|
||||
|
||||
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
|
||||
master = Tracker(verbose = verbose, hostIP = hostIP)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#!/usr/bin/python
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This is a script to submit rabit job via Yarn
|
||||
rabit will run as a Yarn application
|
||||
@ -13,6 +13,7 @@ import rabit_tracker as tracker
|
||||
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
|
||||
YARN_BOOT_PY = os.path.dirname(__file__) + '/../yarn/run_hdfs_prog.py'
|
||||
|
||||
if not os.path.exists(YARN_JAR_PATH):
|
||||
warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
|
||||
@ -21,7 +22,7 @@ if not os.path.exists(YARN_JAR_PATH):
|
||||
subprocess.check_call(cmd, shell = True, env = os.environ)
|
||||
assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
|
||||
|
||||
hadoop_binary = 'hadoop'
|
||||
hadoop_binary = None
|
||||
# code
|
||||
hadoop_home = os.getenv('HADOOP_HOME')
|
||||
|
||||
@ -38,6 +39,8 @@ parser.add_argument('-hip', '--host_ip', default='auto', type=str,
|
||||
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('-q', '--queue', default='default', type=str,
|
||||
help = 'the queue we want to submit the job to')
|
||||
parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
|
||||
help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
|
||||
parser.add_argument('-f', '--files', default = [], action='append',
|
||||
@ -56,6 +59,11 @@ parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
|
||||
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
|
||||
'if you are running multi-threading rabit,'\
|
||||
'so that each node can occupy all the mapper slots in a machine for maximum performance')
|
||||
parser.add_argument('--libhdfs-opts', default='-Xmx128m', type=str,
|
||||
help = 'setting to be passed to libhdfs')
|
||||
parser.add_argument('--name-node', default='default', type=str,
|
||||
help = 'the namenode address of hdfs, libhdfs should connect to, normally leave it as default')
|
||||
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
@ -87,7 +95,7 @@ if hadoop_version < 2:
|
||||
print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]
|
||||
|
||||
def submit_yarn(nworker, worker_args, worker_env):
|
||||
fset = set([YARN_JAR_PATH])
|
||||
fset = set([YARN_JAR_PATH, YARN_BOOT_PY])
|
||||
if args.auto_file_cache != 0:
|
||||
for i in range(len(args.command)):
|
||||
f = args.command[i]
|
||||
@ -96,7 +104,7 @@ def submit_yarn(nworker, worker_args, worker_env):
|
||||
if i == 0:
|
||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||
else:
|
||||
args.command[i] = args.command[i].split('/')[-1]
|
||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||
if args.command[0].endswith('.py'):
|
||||
flst = [WRAPPER_PATH + '/rabit.py',
|
||||
WRAPPER_PATH + '/librabit_wrapper.so',
|
||||
@ -112,6 +120,8 @@ def submit_yarn(nworker, worker_args, worker_env):
|
||||
env['rabit_cpu_vcores'] = str(args.vcores)
|
||||
env['rabit_memory_mb'] = str(args.memory_mb)
|
||||
env['rabit_world_size'] = str(args.nworker)
|
||||
env['rabit_hdfs_opts'] = str(args.libhdfs_opts)
|
||||
env['rabit_hdfs_namenode'] = str(args.name_node)
|
||||
|
||||
if args.files != None:
|
||||
for flst in args.files:
|
||||
@ -121,7 +131,8 @@ def submit_yarn(nworker, worker_args, worker_env):
|
||||
cmd += ' -file %s' % f
|
||||
cmd += ' -jobname %s ' % args.jobname
|
||||
cmd += ' -tempdir %s ' % args.tempdir
|
||||
cmd += (' '.join(args.command + worker_args))
|
||||
cmd += ' -queue %s ' % args.queue
|
||||
cmd += (' '.join(['./run_hdfs_prog.py'] + args.command + worker_args))
|
||||
if args.verbose != 0:
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True, env = env)
|
||||
|
||||
@ -1 +0,0 @@
|
||||
foler used to hold generated class files
|
||||
@ -1,4 +1,8 @@
|
||||
#!/bin/bash
|
||||
CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
|
||||
if [ ! -d bin ]; then
|
||||
mkdir bin
|
||||
fi
|
||||
|
||||
CPATH=`${HADOOP_HOME}/bin/hadoop classpath`
|
||||
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
|
||||
jar cf rabit-yarn.jar -C bin .
|
||||
|
||||
45
yarn/run_hdfs_prog.py
Executable file
45
yarn/run_hdfs_prog.py
Executable file
@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
this script helps setup classpath env for HDFS, before running program
|
||||
that links with libhdfs
|
||||
"""
|
||||
import glob
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print 'Usage: the command you want to run'
|
||||
|
||||
hadoop_home = os.getenv('HADOOP_HOME')
|
||||
hdfs_home = os.getenv('HADOOP_HDFS_HOME')
|
||||
java_home = os.getenv('JAVA_HOME')
|
||||
if hadoop_home is None:
|
||||
hadoop_home = os.getenv('HADOOP_PREFIX')
|
||||
assert hadoop_home is not None, 'need to set HADOOP_HOME'
|
||||
assert hdfs_home is not None, 'need to set HADOOP_HDFS_HOME'
|
||||
assert java_home is not None, 'need to set JAVA_HOME'
|
||||
|
||||
(classpath, err) = subprocess.Popen('%s/bin/hadoop classpath' % hadoop_home,
|
||||
stdout=subprocess.PIPE, shell = True,
|
||||
env = os.environ).communicate()
|
||||
cpath = []
|
||||
for f in classpath.split(':'):
|
||||
cpath += glob.glob(f)
|
||||
|
||||
lpath = []
|
||||
lpath.append('%s/lib/native' % hdfs_home)
|
||||
lpath.append('%s/jre/lib/amd64/server' % java_home)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['CLASSPATH'] = '${CLASSPATH}:' + (':'.join(cpath))
|
||||
|
||||
# setup hdfs options
|
||||
if 'rabit_hdfs_opts' in env:
|
||||
env['LIBHDFS_OPTS'] = env['rabit_hdfs_opts']
|
||||
elif 'LIBHDFS_OPTS' not in env:
|
||||
env['LIBHDFS_OPTS'] = '--Xmx128m'
|
||||
|
||||
env['LD_LIBRARY_PATH'] = '${LD_LIBRARY_PATH}:' + (':'.join(lpath))
|
||||
ret = subprocess.call(args = sys.argv[1:], env = env)
|
||||
sys.exit(ret)
|
||||
@ -1,5 +1,6 @@
|
||||
package org.apache.hadoop.yarn.rabit;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
@ -7,6 +8,7 @@ import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
@ -34,6 +36,7 @@ import org.apache.hadoop.yarn.api.records.NodeReport;
|
||||
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
|
||||
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
|
||||
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
|
||||
import org.apache.hadoop.security.UserGroupInformation;
|
||||
|
||||
/**
|
||||
* application master for allocating resources of rabit client
|
||||
@ -61,6 +64,8 @@ public class ApplicationMaster {
|
||||
// command to launch
|
||||
private String command = "";
|
||||
|
||||
// username
|
||||
private String userName = "";
|
||||
// application tracker hostname
|
||||
private String appHostName = "";
|
||||
// tracker URL to do
|
||||
@ -128,6 +133,8 @@ public class ApplicationMaster {
|
||||
*/
|
||||
private void initArgs(String args[]) throws IOException {
|
||||
LOG.info("Invoke initArgs");
|
||||
// get user name
|
||||
userName = UserGroupInformation.getCurrentUser().getShortUserName();
|
||||
// cached maps
|
||||
Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
|
||||
for (int i = 0; i < args.length; ++i) {
|
||||
@ -156,7 +163,8 @@ public class ApplicationMaster {
|
||||
numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
|
||||
numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
|
||||
numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
|
||||
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false, maxNumAttempt);
|
||||
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false,
|
||||
maxNumAttempt);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -175,7 +183,7 @@ public class ApplicationMaster {
|
||||
RegisterApplicationMasterResponse response = this.rmClient
|
||||
.registerApplicationMaster(this.appHostName,
|
||||
this.appTrackerPort, this.appTrackerUrl);
|
||||
|
||||
|
||||
boolean success = false;
|
||||
String diagnostics = "";
|
||||
try {
|
||||
@ -208,19 +216,20 @@ public class ApplicationMaster {
|
||||
assert (killedTasks.size() + finishedTasks.size() == numTasks);
|
||||
success = finishedTasks.size() == numTasks;
|
||||
LOG.info("Application completed. Stopping running containers");
|
||||
nmClient.stop();
|
||||
diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
|
||||
+ ", finished=" + this.finishedTasks.size() + ", failed="
|
||||
+ this.killedTasks.size() + "\n" + this.abortDiagnosis;
|
||||
+ ", finished=" + this.finishedTasks.size() + ", failed="
|
||||
+ this.killedTasks.size() + "\n" + this.abortDiagnosis;
|
||||
nmClient.stop();
|
||||
LOG.info(diagnostics);
|
||||
} catch (Exception e) {
|
||||
diagnostics = e.toString();
|
||||
}
|
||||
}
|
||||
rmClient.unregisterApplicationMaster(
|
||||
success ? FinalApplicationStatus.SUCCEEDED
|
||||
: FinalApplicationStatus.FAILED, diagnostics,
|
||||
appTrackerUrl);
|
||||
if (!success) throw new Exception("Application not successful");
|
||||
if (!success)
|
||||
throw new Exception("Application not successful");
|
||||
}
|
||||
|
||||
/**
|
||||
@ -265,30 +274,63 @@ public class ApplicationMaster {
|
||||
task.containerRequest = null;
|
||||
ContainerLaunchContext ctx = Records
|
||||
.newRecord(ContainerLaunchContext.class);
|
||||
String cmd =
|
||||
// use this to setup CLASSPATH correctly for libhdfs
|
||||
"CLASSPATH=${CLASSPATH}:`${HADOOP_PREFIX}/bin/hadoop classpath --glob` "
|
||||
+ this.command + " 1>"
|
||||
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
|
||||
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
|
||||
+ "/stderr";
|
||||
LOG.info(cmd);
|
||||
String hadoop = "hadoop";
|
||||
if (System.getenv("HADOOP_HOME") != null) {
|
||||
hadoop = "${HADOOP_HOME}/bin/hadoop";
|
||||
} else if (System.getenv("HADOOP_PREFIX") != null) {
|
||||
hadoop = "${HADOOP_PREFIX}/bin/hadoop";
|
||||
}
|
||||
|
||||
String cmd =
|
||||
// use this to setup CLASSPATH correctly for libhdfs
|
||||
this.command + " 1>"
|
||||
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
|
||||
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
|
||||
+ "/stderr";
|
||||
ctx.setCommands(Collections.singletonList(cmd));
|
||||
LOG.info(workerResources);
|
||||
ctx.setLocalResources(this.workerResources);
|
||||
// setup environment variables
|
||||
Map<String, String> env = new java.util.HashMap<String, String>();
|
||||
|
||||
|
||||
// setup class path, this is kind of duplicated, ignoring
|
||||
StringBuilder cpath = new StringBuilder("${CLASSPATH}:./*");
|
||||
for (String c : conf.getStrings(
|
||||
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
|
||||
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
|
||||
cpath.append(':');
|
||||
cpath.append(c.trim());
|
||||
String[] arrPath = c.split(":");
|
||||
for (String ps : arrPath) {
|
||||
if (ps.endsWith("*.jar") || ps.endsWith("*")) {
|
||||
ps = ps.substring(0, ps.lastIndexOf('*'));
|
||||
String prefix = ps.substring(0, ps.lastIndexOf('/'));
|
||||
if (ps.startsWith("$")) {
|
||||
String[] arr =ps.split("/", 2);
|
||||
if (arr.length != 2) continue;
|
||||
try {
|
||||
ps = System.getenv(arr[0].substring(1)) + '/' + arr[1];
|
||||
} catch (Exception e){
|
||||
continue;
|
||||
}
|
||||
}
|
||||
File dir = new File(ps);
|
||||
if (dir.isDirectory()) {
|
||||
for (File f: dir.listFiles()) {
|
||||
if (f.isFile() && f.getPath().endsWith(".jar")) {
|
||||
cpath.append(":");
|
||||
cpath.append(prefix + '/' + f.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cpath.append(':');
|
||||
cpath.append(ps.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
// already use hadoop command to get class path in worker, maybe a better solution in future
|
||||
// env.put("CLASSPATH", cpath.toString());
|
||||
// already use hadoop command to get class path in worker, maybe a
|
||||
// better solution in future
|
||||
env.put("CLASSPATH", cpath.toString());
|
||||
//LOG.info("CLASSPATH =" + cpath.toString());
|
||||
// setup LD_LIBARY_PATH path for libhdfs
|
||||
env.put("LD_LIBRARY_PATH",
|
||||
"${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
|
||||
@ -298,10 +340,13 @@ public class ApplicationMaster {
|
||||
if (e.getKey().startsWith("rabit_")) {
|
||||
env.put(e.getKey(), e.getValue());
|
||||
}
|
||||
if (e.getKey() == "LIBHDFS_OPTS") {
|
||||
env.put(e.getKey(), e.getValue());
|
||||
}
|
||||
}
|
||||
env.put("rabit_task_id", String.valueOf(task.taskId));
|
||||
env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
|
||||
|
||||
// ctx.setUser(userName);
|
||||
ctx.setEnvironment(env);
|
||||
synchronized (this) {
|
||||
assert (!this.runningTasks.containsKey(container.getId()));
|
||||
@ -376,8 +421,17 @@ public class ApplicationMaster {
|
||||
Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
|
||||
for (ContainerId cid : failed) {
|
||||
TaskRecord r = runningTasks.remove(cid);
|
||||
if (r == null)
|
||||
if (r == null) {
|
||||
continue;
|
||||
}
|
||||
LOG.info("Task "
|
||||
+ r.taskId
|
||||
+ "failed on "
|
||||
+ r.container.getId()
|
||||
+ ". See LOG at : "
|
||||
+ String.format("http://%s/node/containerlogs/%s/"
|
||||
+ userName, r.container.getNodeHttpAddress(),
|
||||
r.container.getId()));
|
||||
r.attemptCounter += 1;
|
||||
r.container = null;
|
||||
tasks.add(r);
|
||||
@ -411,22 +465,26 @@ public class ApplicationMaster {
|
||||
finishedTasks.add(r);
|
||||
runningTasks.remove(s.getContainerId());
|
||||
} else {
|
||||
switch (exstatus) {
|
||||
case ContainerExitStatus.KILLED_EXCEEDED_PMEM:
|
||||
this.abortJob("[Rabit] Task "
|
||||
+ r.taskId
|
||||
+ " killed because of exceeding allocated physical memory");
|
||||
break;
|
||||
case ContainerExitStatus.KILLED_EXCEEDED_VMEM:
|
||||
this.abortJob("[Rabit] Task "
|
||||
+ r.taskId
|
||||
+ " killed because of exceeding allocated virtual memory");
|
||||
break;
|
||||
default:
|
||||
LOG.info("[Rabit] Task " + r.taskId
|
||||
+ " exited with status " + exstatus);
|
||||
failed.add(s.getContainerId());
|
||||
try {
|
||||
if (exstatus == ContainerExitStatus.class.getField(
|
||||
"KILLED_EXCEEDED_PMEM").getInt(null)) {
|
||||
this.abortJob("[Rabit] Task "
|
||||
+ r.taskId
|
||||
+ " killed because of exceeding allocated physical memory");
|
||||
continue;
|
||||
}
|
||||
if (exstatus == ContainerExitStatus.class.getField(
|
||||
"KILLED_EXCEEDED_VMEM").getInt(null)) {
|
||||
this.abortJob("[Rabit] Task "
|
||||
+ r.taskId
|
||||
+ " killed because of exceeding allocated virtual memory");
|
||||
continue;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
}
|
||||
LOG.info("[Rabit] Task " + r.taskId + " exited with status "
|
||||
+ exstatus + " Diagnostics:"+ s.getDiagnostics());
|
||||
failed.add(s.getContainerId());
|
||||
}
|
||||
}
|
||||
this.handleFailure(failed);
|
||||
|
||||
@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.fs.FileStatus;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.permission.FsPermission;
|
||||
import org.apache.hadoop.security.UserGroupInformation;
|
||||
import org.apache.hadoop.yarn.api.ApplicationConstants;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationReport;
|
||||
@ -19,6 +20,7 @@ import org.apache.hadoop.yarn.api.records.LocalResource;
|
||||
import org.apache.hadoop.yarn.api.records.LocalResourceType;
|
||||
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
|
||||
import org.apache.hadoop.yarn.api.records.Resource;
|
||||
import org.apache.hadoop.yarn.api.records.QueueInfo;
|
||||
import org.apache.hadoop.yarn.api.records.YarnApplicationState;
|
||||
import org.apache.hadoop.yarn.client.api.YarnClient;
|
||||
import org.apache.hadoop.yarn.client.api.YarnClientApplication;
|
||||
@ -43,14 +45,21 @@ public class Client {
|
||||
private String appArgs = "";
|
||||
// HDFS Path to store temporal result
|
||||
private String tempdir = "/tmp";
|
||||
// user name
|
||||
private String userName = "";
|
||||
// job name
|
||||
private String jobName = "";
|
||||
// queue
|
||||
private String queue = "default";
|
||||
/**
|
||||
* constructor
|
||||
* @throws IOException
|
||||
*/
|
||||
private Client() throws IOException {
|
||||
conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/core-site.xml"));
|
||||
conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/hdfs-site.xml"));
|
||||
dfs = FileSystem.get(conf);
|
||||
userName = UserGroupInformation.getCurrentUser().getShortUserName();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -127,6 +136,9 @@ public class Client {
|
||||
if (e.getKey().startsWith("rabit_")) {
|
||||
env.put(e.getKey(), e.getValue());
|
||||
}
|
||||
if (e.getKey() == "LIBHDFS_OPTS") {
|
||||
env.put(e.getKey(), e.getValue());
|
||||
}
|
||||
}
|
||||
LOG.debug(env);
|
||||
return env;
|
||||
@ -152,6 +164,8 @@ public class Client {
|
||||
this.jobName = args[++i];
|
||||
} else if(args[i].equals("-tempdir")) {
|
||||
this.tempdir = args[++i];
|
||||
} else if(args[i].equals("-queue")) {
|
||||
this.queue = args[++i];
|
||||
} else {
|
||||
sargs.append(" ");
|
||||
sargs.append(args[i]);
|
||||
@ -168,7 +182,6 @@ public class Client {
|
||||
}
|
||||
this.initArgs(args);
|
||||
// Create yarnClient
|
||||
YarnConfiguration conf = new YarnConfiguration();
|
||||
YarnClient yarnClient = YarnClient.createYarnClient();
|
||||
yarnClient.init(conf);
|
||||
yarnClient.start();
|
||||
@ -181,13 +194,14 @@ public class Client {
|
||||
.newRecord(ContainerLaunchContext.class);
|
||||
ApplicationSubmissionContext appContext = app
|
||||
.getApplicationSubmissionContext();
|
||||
|
||||
// Submit application
|
||||
ApplicationId appId = appContext.getApplicationId();
|
||||
// setup cache-files and environment variables
|
||||
amContainer.setLocalResources(this.setupCacheFiles(appId));
|
||||
amContainer.setEnvironment(this.getEnvironment());
|
||||
String cmd = "$JAVA_HOME/bin/java"
|
||||
+ " -Xmx256M"
|
||||
+ " -Xmx900M"
|
||||
+ " org.apache.hadoop.yarn.rabit.ApplicationMaster"
|
||||
+ this.cacheFileArg + ' ' + this.appArgs + " 1>"
|
||||
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
|
||||
@ -197,15 +211,15 @@ public class Client {
|
||||
|
||||
// Set up resource type requirements for ApplicationMaster
|
||||
Resource capability = Records.newRecord(Resource.class);
|
||||
capability.setMemory(256);
|
||||
capability.setMemory(1024);
|
||||
capability.setVirtualCores(1);
|
||||
LOG.info("jobname=" + this.jobName);
|
||||
|
||||
LOG.info("jobname=" + this.jobName + ",username=" + this.userName);
|
||||
|
||||
appContext.setApplicationName(jobName + ":RABIT-YARN");
|
||||
appContext.setAMContainerSpec(amContainer);
|
||||
appContext.setResource(capability);
|
||||
appContext.setQueue("default");
|
||||
|
||||
appContext.setQueue(queue);
|
||||
//appContext.setUser(userName);
|
||||
LOG.info("Submitting application " + appId);
|
||||
yarnClient.submitApplication(appContext);
|
||||
|
||||
@ -218,12 +232,16 @@ public class Client {
|
||||
appReport = yarnClient.getApplicationReport(appId);
|
||||
appState = appReport.getYarnApplicationState();
|
||||
}
|
||||
|
||||
|
||||
System.out.println("Application " + appId + " finished with"
|
||||
+ " state " + appState + " at " + appReport.getFinishTime());
|
||||
if (!appReport.getFinalApplicationStatus().equals(
|
||||
FinalApplicationStatus.SUCCEEDED)) {
|
||||
System.err.println(appReport.getDiagnostics());
|
||||
System.out.println("Available queues:");
|
||||
for (QueueInfo q : yarnClient.getAllQueues()) {
|
||||
System.out.println(q.getQueueName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user