commit 9ccbeaa8f0

Merge commit '75bf97b57539e5572e7ae8eba72bac6562c63c07'

Conflicts:
    subtree/rabit/rabit-learn/io/line_split-inl.h
    subtree/rabit/yarn/build.sh
@@ -19,6 +19,8 @@ namespace utils {
 /*! \brief interface of i/o stream that support seek */
 class ISeekStream: public IStream {
  public:
+// virtual destructor
+virtual ~ISeekStream(void) {}
 /*! \brief seek to certain position of the file */
 virtual void Seek(size_t pos) = 0;
 /*! \brief tell the position of the stream */
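The virtual destructor added to ISeekStream matters because later hunks in this commit delete streams through the base pointer (for example LineSplitter deletes its utils::ISeekStream *fs_); without it, destructors of derived streams such as FileStream or HDFSStream would never run. A minimal illustration of the pattern, assuming a FileStream-like derived class (sketch, not part of the patch):

    // sketch: deleting a derived stream through the base pointer
    rabit::utils::ISeekStream *fs = new FileStream("part-00000", "rb");
    // ... read through the ISeekStream interface ...
    delete fs;  // with the virtual destructor, FileStream's cleanup (fclose) actually runs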
new file: subtree/rabit/rabit-learn/.gitignore (vendored, +2)
@@ -0,0 +1,2 @@
+config.mk
+*.log
@@ -38,6 +38,7 @@ class StreamBufferReader {
 }
 }
 }
+/*! \brief whether we are reaching the end of file */
 inline bool AtEnd(void) const {
 return read_len_ == 0;
 }
@@ -66,27 +66,36 @@ class FileStream : public utils::ISeekStream {
 };

 /*! \brief line split from normal file system */
-class FileSplit : public LineSplitBase {
+class FileProvider : public LineSplitter::IFileProvider {
  public:
-explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
+explicit FileProvider(const char *uri) {
-LineSplitBase::SplitNames(&fnames_, uri, "#");
+LineSplitter::SplitNames(&fnames_, uri, "#");
 std::vector<size_t> fsize;
 for (size_t i = 0; i < fnames_.size(); ++i) {
 if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
 std::string tmp = fnames_[i].c_str() + 7;
 fnames_[i] = tmp;
 }
-fsize.push_back(GetFileSize(fnames_[i].c_str()));
+size_t fz = GetFileSize(fnames_[i].c_str());
+if (fz != 0) {
+fsize_.push_back(fz);
+}
 }
-LineSplitBase::Init(fsize, rank, nsplit);
 }
-virtual ~FileSplit(void) {}
+// destrucor
+virtual ~FileProvider(void) {}
-protected:
+virtual utils::ISeekStream *Open(size_t file_index) {
-virtual utils::ISeekStream *GetFile(size_t file_index) {
 utils::Assert(file_index < fnames_.size(), "file index exceed bound");
 return new FileStream(fnames_[file_index].c_str(), "rb");
 }
+virtual const std::vector<size_t> &FileSize(void) const {
+return fsize_;
+}
+private:
+// file sizes
+std::vector<size_t> fsize_;
+// file names
+std::vector<std::string> fnames_;
 // get file size
 inline static size_t GetFileSize(const char *fname) {
 std::FILE *fp = utils::FopenCheck(fname, "rb");
@@ -96,10 +105,6 @@ class FileSplit : public LineSplitBase {
 std::fclose(fp);
 return fsize;
 }
-
-private:
-// file names
-std::vector<std::string> fnames_;
 };
 } // namespace io
 } // namespace rabit
@@ -6,6 +6,7 @@
 * \author Tianqi Chen
 */
 #include <string>
+#include <cstdlib>
 #include <vector>
 #include <hdfs.h>
 #include <errno.h>
@@ -15,11 +16,15 @@
 /*! \brief io interface */
 namespace rabit {
 namespace io {
-class HDFSStream : public utils::ISeekStream {
+class HDFSStream : public ISeekStream {
  public:
-HDFSStream(hdfsFS fs, const char *fname, const char *mode)
+HDFSStream(hdfsFS fs,
-: fs_(fs), at_end_(false) {
+const char *fname,
-int flag;
+const char *mode,
+bool disconnect_when_done)
+: fs_(fs), at_end_(false),
+disconnect_when_done_(disconnect_when_done) {
+int flag = 0;
 if (!strcmp(mode, "r")) {
 flag = O_RDONLY;
 } else if (!strcmp(mode, "w")) {
@@ -35,6 +40,9 @@ class HDFSStream : public utils::ISeekStream {
 }
 virtual ~HDFSStream(void) {
 this->Close();
+if (disconnect_when_done_) {
+utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
+}
 }
 virtual size_t Read(void *ptr, size_t size) {
 tSize nread = hdfsRead(fs_, fp_, ptr, size);
@@ -86,52 +94,69 @@ class HDFSStream : public utils::ISeekStream {
 }
 }

+inline static std::string GetNameNode(void) {
+const char *nn = getenv("rabit_hdfs_namenode");
+if (nn == NULL) {
+return std::string("default");
+} else {
+return std::string(nn);
+}
+}
 private:
 hdfsFS fs_;
 hdfsFile fp_;
 bool at_end_;
+bool disconnect_when_done_;
 };

 /*! \brief line split from normal file system */
-class HDFSSplit : public LineSplitBase {
+class HDFSProvider : public LineSplitter::IFileProvider {
  public:
-explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) {
+explicit HDFSProvider(const char *uri) {
-fs_ = hdfsConnect("default", 0);
+fs_ = hdfsConnect(HDFSStream::GetNameNode().c_str(), 0);
+utils::Check(fs_ != NULL, "error when connecting to default HDFS");
 std::vector<std::string> paths;
-LineSplitBase::SplitNames(&paths, uri, "#");
+LineSplitter::SplitNames(&paths, uri, "#");
 // get the files
-std::vector<size_t> fsize;
 for (size_t i = 0; i < paths.size(); ++i) {
 hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
+utils::Check(info != NULL, "path %s do not exist", paths[i].c_str());
 if (info->mKind == 'D') {
 int nentry;
 hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
+utils::Check(files != NULL, "error when ListDirectory %s", info->mName);
 for (int i = 0; i < nentry; ++i) {
-if (files[i].mKind == 'F') {
+if (files[i].mKind == 'F' && files[i].mSize != 0) {
-fsize.push_back(files[i].mSize);
+fsize_.push_back(files[i].mSize);
 fnames_.push_back(std::string(files[i].mName));
 }
 }
 hdfsFreeFileInfo(files, nentry);
 } else {
-fsize.push_back(info->mSize);
+if (info->mSize != 0) {
-fnames_.push_back(std::string(info->mName));
+fsize_.push_back(info->mSize);
+fnames_.push_back(std::string(info->mName));
+}
 }
 hdfsFreeFileInfo(info, 1);
 }
-LineSplitBase::Init(fsize, rank, nsplit);
 }
-virtual ~HDFSSplit(void) {}
+virtual ~HDFSProvider(void) {
+utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
-protected:
+}
-virtual utils::ISeekStream *GetFile(size_t file_index) {
+virtual const std::vector<size_t> &FileSize(void) const {
+return fsize_;
+}
+virtual ISeekStream *Open(size_t file_index) {
 utils::Assert(file_index < fnames_.size(), "file index exceed bound");
-return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
+return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
 }

 private:
 // hdfs handle
 hdfsFS fs_;
+// file sizes
+std::vector<size_t> fsize_;
 // file names
 std::vector<std::string> fnames_;
 };
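GetNameNode() makes the HDFS namenode configurable through the rabit_hdfs_namenode environment variable instead of the hard-coded "default". A hedged sketch of the same fallback logic in isolation (names outside this diff are illustrative):

    #include <cstdlib>
    #include <string>
    // mirrors the lookup shown above: use the env var if set, otherwise "default"
    std::string PickNameNode() {
      const char *nn = std::getenv("rabit_hdfs_namenode");
      return (nn == NULL) ? std::string("default") : std::string(nn);
    }
    // usage, assuming libhdfs is available:
    //   hdfsFS fs = hdfsConnect(PickNameNode().c_str(), 0);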
@@ -30,16 +30,16 @@ inline InputSplit *CreateInputSplit(const char *uri,
 return new SingleFileSplit(uri);
 }
 if (!strncmp(uri, "file://", 7)) {
-return new FileSplit(uri, part, nsplit);
+return new LineSplitter(new FileProvider(uri), part, nsplit);
 }
 if (!strncmp(uri, "hdfs://", 7)) {
 #if RABIT_USE_HDFS
-return new HDFSSplit(uri, part, nsplit);
+return new LineSplitter(new HDFSProvider(uri), part, nsplit);
 #else
 utils::Error("Please compile with RABIT_USE_HDFS=1");
 #endif
 }
-return new FileSplit(uri, part, nsplit);
+return new LineSplitter(new FileProvider(uri), part, nsplit);
 }
 /*!
 * \brief create an stream, the stream must be able to close
@@ -55,7 +55,8 @@ inline IStream *CreateStream(const char *uri, const char *mode) {
 }
 if (!strncmp(uri, "hdfs://", 7)) {
 #if RABIT_USE_HDFS
-return new HDFSStream(hdfsConnect("default", 0), uri, mode);
+return new HDFSStream(hdfsConnect(HDFSStream::GetNameNode().c_str(), 0),
+uri, mode, true);
 #else
 utils::Error("Please compile with RABIT_USE_HDFS=1");
 #endif
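After this change every line-based URI goes through one LineSplitter that is parameterized by a protocol-specific IFileProvider. A usage sketch of the factory, assuming the NextLine interface shown elsewhere in this diff (illustrative only, not part of the patch):

    // read this worker's share of lines from a local or HDFS file
    rabit::io::InputSplit *in =
        rabit::io::CreateInputSplit("file:///tmp/train.txt", 0 /* part */, 4 /* nsplit */);
    std::string line;
    while (in->NextLine(&line)) {
      // parse one line of this worker's slice
    }
    delete in;  // also releases the provider owned by the splitter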
@@ -19,6 +19,7 @@ namespace rabit {
 * \brief namespace to handle input split and filesystem interfacing
 */
 namespace io {
+/*! \brief reused ISeekStream's definition */
 typedef utils::ISeekStream ISeekStream;
 /*!
 * \brief user facing input split helper,
@@ -15,11 +15,42 @@

 namespace rabit {
 namespace io {
-class LineSplitBase : public InputSplit {
+/*! \brief class that split the files by line */
+class LineSplitter : public InputSplit {
  public:
-virtual ~LineSplitBase() {
+class IFileProvider {
-if (fs_ != NULL) delete fs_;
+public:
+/*!
+* \brief get the seek stream of given file_index
+* \return the corresponding seek stream at head of the stream
+* the seek stream's resource can be freed by calling delete
+*/
+virtual ISeekStream *Open(size_t file_index) = 0;
+/*!
+* \return const reference to size of each files
+*/
+virtual const std::vector<size_t> &FileSize(void) const = 0;
+// virtual destructor
+virtual ~IFileProvider() {}
+};
+// constructor
+explicit LineSplitter(IFileProvider *provider,
+unsigned rank,
+unsigned nsplit)
+: provider_(provider), fs_(NULL),
+reader_(kBufferSize) {
+this->Init(provider_->FileSize(), rank, nsplit);
 }
+// destructor
+virtual ~LineSplitter() {
+if (fs_ != NULL) {
+delete fs_; fs_ = NULL;
+}
+// delete provider after destructing the streams
+delete provider_;
+}
+// get next line
 virtual bool NextLine(std::string *out_data) {
 if (file_ptr_ >= file_ptr_end_ &&
 offset_curr_ >= offset_end_) return false;
@@ -29,15 +60,15 @@ class LineSplitBase : public InputSplit {
 if (reader_.AtEnd()) {
 if (out_data->length() != 0) return true;
 file_ptr_ += 1;
+if (offset_curr_ >= offset_end_) return false;
 if (offset_curr_ != file_offset_[file_ptr_]) {
-utils::Error("warning:std::FILE size not calculated correctly\n");
+utils::Error("warning: FILE size not calculated correctly\n");
 offset_curr_ = file_offset_[file_ptr_];
 }
-if (offset_curr_ >= offset_end_) return false;
 utils::Assert(file_ptr_ + 1 < file_offset_.size(),
 "boundary check");
 delete fs_;
-fs_ = this->GetFile(file_ptr_);
+fs_ = provider_->Open(file_ptr_);
 reader_.set_stream(fs_);
 } else {
 ++offset_curr_;
@@ -51,15 +82,27 @@ class LineSplitBase : public InputSplit {
 }
 }
 }
+/*!
-protected:
+* \brief split names given
-// constructor
+* \param out_fname output std::FILE names
-LineSplitBase(void)
+* \param uri_ the iput uri std::FILE
-: fs_(NULL), reader_(kBufferSize) {
+* \param dlm deliminetr
+*/
+inline static void SplitNames(std::vector<std::string> *out_fname,
+const char *uri_,
+const char *dlm) {
+std::string uri = uri_;
+char *p = std::strtok(BeginPtr(uri), dlm);
+while (p != NULL) {
+out_fname->push_back(std::string(p));
+p = std::strtok(NULL, dlm);
+}
 }

+private:
 /*!
 * \brief initialize the line spliter,
-* \param file_size, size of each std::FILEs
+* \param file_size, size of each files
 * \param rank the current rank of the data
 * \param nsplit number of split we will divide the data into
 */
@@ -82,7 +125,7 @@ class LineSplitBase : public InputSplit {
 file_ptr_end_ = std::upper_bound(file_offset_.begin(),
 file_offset_.end(),
 offset_end_) - file_offset_.begin() - 1;
-fs_ = GetFile(file_ptr_);
+fs_ = provider_->Open(file_ptr_);
 reader_.set_stream(fs_);
 // try to set the starting position correctly
 if (file_offset_[file_ptr_] != offset_begin_) {
@@ -94,33 +137,15 @@ class LineSplitBase : public InputSplit {
 }
 }
 }
-/*!
-* \brief get the seek stream of given file_index
-* \return the corresponding seek stream at head of std::FILE
-*/
-virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
-/*!
-* \brief split names given
-* \param out_fname output std::FILE names
-* \param uri_ the iput uri std::FILE
-* \param dlm deliminetr
-*/
-inline static void SplitNames(std::vector<std::string> *out_fname,
-const char *uri_,
-const char *dlm) {
-std::string uri = uri_;
-char *p = std::strtok(BeginPtr(uri), dlm);
-while (p != NULL) {
-out_fname->push_back(std::string(p));
-p = std::strtok(NULL, dlm);
-}
-}
 private:
+/*! \brief FileProvider */
+IFileProvider *provider_;
 /*! \brief current input stream */
 utils::ISeekStream *fs_;
-/*! \brief std::FILE pointer of which std::FILE to read on */
+/*! \brief file pointer of which file to read on */
 size_t file_ptr_;
-/*! \brief std::FILE pointer where the end of std::FILE lies */
+/*! \brief file pointer where the end of file lies */
 size_t file_ptr_end_;
 /*! \brief get the current offset */
 size_t offset_curr_;
@@ -128,7 +153,7 @@ class LineSplitBase : public InputSplit {
 size_t offset_begin_;
 /*! \brief end of the offset */
 size_t offset_end_;
-/*! \brief byte-offset of each std::FILE */
+/*! \brief byte-offset of each file */
 std::vector<size_t> file_offset_;
 /*! \brief buffer reader */
 StreamBufferReader reader_;
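The IFileProvider interface introduced above is the extension point: a provider only reports per-file sizes and opens a seekable stream on demand, while LineSplitter owns the byte-range bookkeeping and deletes the provider in its destructor. A minimal custom provider, sketched under the assumption that FileStream and a size helper like FileProvider's GetFileSize are available (illustrative, not part of the commit):

    class SingleFileProvider : public rabit::io::LineSplitter::IFileProvider {
     public:
      explicit SingleFileProvider(const char *path) : fname_(path) {
        fsize_.push_back(GetFileSize(path));  // assumed size helper, as in FileProvider
      }
      virtual rabit::io::ISeekStream *Open(size_t file_index) {
        return new FileStream(fname_.c_str(), "rb");
      }
      virtual const std::vector<size_t> &FileSize(void) const { return fsize_; }
     private:
      std::string fname_;
      std::vector<size_t> fsize_;
    };
    // wiring, as CreateInputSplit does for FileProvider/HDFSProvider:
    //   InputSplit *in = new LineSplitter(new SingleFileProvider("train.txt"), rank, nsplit);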
@@ -1,4 +1,10 @@
-# specify tensor path
+ifneq ("$(wildcard ../config.mk)","")
+config = ../config.mk
+else
+config = ../make/config.mk
+endif
+include $(config)
+
 BIN = linear.rabit
 MOCKBIN= linear.mock
 MPIBIN =
@@ -6,10 +12,10 @@ MPIBIN =
 OBJ = linear.o

 # common build script for programs
-include ../make/config.mk
 include ../make/common.mk
 CFLAGS+=-fopenmp
 linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
 # dependenies here
 linear.rabit: linear.o lib
 linear.mock: linear.o lib
+
@@ -206,21 +206,22 @@ int main(int argc, char *argv[]) {
 rabit::Finalize();
 return 0;
 }
-rabit::linear::LinearObjFunction linear;
+rabit::linear::LinearObjFunction *linear = new rabit::linear::LinearObjFunction();
 if (!strcmp(argv[1], "stdin")) {
-linear.LoadData(argv[1]);
+linear->LoadData(argv[1]);
 rabit::Init(argc, argv);
 } else {
 rabit::Init(argc, argv);
-linear.LoadData(argv[1]);
+linear->LoadData(argv[1]);
 }
 for (int i = 2; i < argc; ++i) {
 char name[256], val[256];
 if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
-linear.SetParam(name, val);
+linear->SetParam(name, val);
 }
 }
-linear.Run();
+linear->Run();
+delete linear;
 rabit::Finalize();
 return 0;
 }
@@ -26,10 +26,11 @@ struct LinearModel {
 int reserved[16];
 // constructor
 ModelParam(void) {
+memset(this, 0, sizeof(ModelParam));
 base_score = 0.5f;
 num_feature = 0;
 loss_type = 1;
-std::memset(reserved, 0, sizeof(reserved));
+num_feature = 0;
 }
 // initialize base score
 inline void InitBaseScore(void) {
@@ -119,7 +120,7 @@ struct LinearModel {
 }
 fi.Read(weight, sizeof(float) * (param.num_feature + 1));
 }
-inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
+inline void Save(rabit::IStream &fo, const float *wptr = NULL) {
 fo.Write(&param, sizeof(param));
 if (wptr == NULL) wptr = weight;
 fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
@@ -6,12 +6,13 @@ then
 fi

 # put the local training file to HDFS
-hadoop fs -rm -r -f $2/data
 hadoop fs -rm -r -f $2/mushroom.linear.model

 hadoop fs -mkdir $2/data
+hadoop fs -put ../data/agaricus.txt.train $2/data

 # submit to hadoop
-../../tracker/rabit_yarn.py -n $1 --vcores 1 linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"
+../../tracker/rabit_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}"

 # get the final model file
 hadoop fs -get $2/mushroom.linear.model ./linear.model
@@ -6,7 +6,7 @@
 #
 # - copy this file to the root of rabit-learn folder
 # - modify the configuration you want
-# - type make or make -j n for parallel build
+# - type make or make -j n on each of the folder
 #----------------------------------------------------

 # choice of compiler
@@ -145,8 +145,9 @@ class LBFGSSolver {

 if (silent == 0 && rabit::GetRank() == 0) {
 rabit::TrackerPrintf
-("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu\n",
+("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu, RAM-approx=%lu\n",
-gstate.num_dim, gstate.init_objval, gstate.size_memory);
+gstate.num_dim, gstate.init_objval, gstate.size_memory,
+gstate.MemCost() + hist.MemCost());
 }
 }
 }
@@ -176,7 +177,7 @@ class LBFGSSolver {
 // swap new weight
 std::swap(g.weight, g.grad);
 // check stop condition
-if (gstate.num_iteration > min_lbfgs_iter) {
+if (gstate.num_iteration > static_cast<size_t>(min_lbfgs_iter)) {
 if (g.old_objval - g.new_objval < lbfgs_stop_tol * g.init_objval) {
 return true;
 }
@@ -195,7 +196,7 @@ class LBFGSSolver {
 /*! \brief run optimization */
 virtual void Run(void) {
 this->Init();
-while (gstate.num_iteration < max_lbfgs_iter) {
+while (gstate.num_iteration < static_cast<size_t>(max_lbfgs_iter)) {
 if (this->UpdateOneIter()) break;
 }
 if (silent == 0 && rabit::GetRank() == 0) {
@@ -225,7 +226,7 @@ class LBFGSSolver {
 const size_t num_dim = gstate.num_dim;
 const DType *gsub = grad + range_begin_;
 const size_t nsub = range_end_ - range_begin_;
-double vdot;
+double vdot = 0.0;
 if (n != 0) {
 // hist[m + n - 1] stores old gradient
 Minus(hist[m + n - 1], gsub, hist[m + n - 1], nsub);
@@ -241,15 +242,19 @@ class LBFGSSolver {
 idxset.push_back(std::make_pair(m + j, 2 * m));
 idxset.push_back(std::make_pair(m + j, m + n - 1));
 }
+
 // calculate dot products
 std::vector<double> tmp(idxset.size());
 for (size_t i = 0; i < tmp.size(); ++i) {
 tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second);
 }
+
 rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size());
+
 for (size_t i = 0; i < tmp.size(); ++i) {
 gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
 }
+
 // BFGS steps, use vector-free update
 // parameterize vector using basis in hist
 std::vector<double> alpha(n);
@@ -263,7 +268,7 @@ class LBFGSSolver {
 }
 alpha[j] = vsum / gstate.DotBuf(j, m + j);
 delta[m + j] = delta[m + j] - alpha[j];
 }
 // scale
 double scale = gstate.DotBuf(n - 1, m + n - 1) /
 gstate.DotBuf(m + n - 1, m + n - 1);
@@ -279,6 +284,7 @@ class LBFGSSolver {
 double beta = vsum / gstate.DotBuf(j, m + j);
 delta[j] = delta[j] + (alpha[j] - beta);
 }
+
 // set all to zero
 std::fill(dir, dir + num_dim, 0.0f);
 DType *dirsub = dir + range_begin_;
@@ -291,10 +297,11 @@ class LBFGSSolver {
 }
 FixDirL1Sign(dirsub, hist[2 * m], nsub);
 vdot = -Dot(dirsub, hist[2 * m], nsub);
+
 // allreduce to get full direction
 rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
 rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
 } else {
 SetL1Dir(dir, grad, weight, num_dim);
 vdot = -Dot(dir, dir, num_dim);
 }
@@ -482,6 +489,7 @@ class LBFGSSolver {
 num_iteration = 0;
 num_dim = 0;
 old_objval = 0.0;
+offset_ = 0;
 }
 ~GlobalState(void) {
 if (grad != NULL) {
@@ -496,6 +504,10 @@ class LBFGSSolver {
 data.resize(n * n, 0.0);
 this->AllocSpace();
 }
+// memory cost
+inline size_t MemCost(void) const {
+return sizeof(DType) * 3 * num_dim;
+}
 inline double &DotBuf(size_t i, size_t j) {
 if (i > j) std::swap(i, j);
 return data[MapIndex(i, offset_, size_memory) * (size_memory * 2 + 1) +
@@ -565,6 +577,10 @@ class LBFGSSolver {
 size_t n = size_memory * 2 + 1;
 dptr_ = new DType[n * stride_];
 }
+// memory cost
+inline size_t MemCost(void) const {
+return sizeof(DType) * (size_memory_ * 2 + 1) * stride_;
+}
 // fetch element from rolling array
 inline const DType *operator[](size_t i) const {
 return dptr_ + MapIndex(i, offset_, size_memory_) * stride_;
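The new MemCost() helpers back the RAM-approx figure added to the startup log: the global state keeps three arrays of num_dim values and the rolling history keeps (2 * size_memory + 1) rows of stride_ values. A worked example of the arithmetic, assuming double precision and, for simplicity, stride_ equal to num_dim (both are assumptions; stride_ is normally the per-worker range):

    // rough reproduction of the MemCost() formulas above, with illustrative numbers
    const size_t num_dim = 1000000, size_memory = 10, stride = num_dim;
    size_t gstate_bytes = sizeof(double) * 3 * num_dim;                    // 24 MB
    size_t hist_bytes   = sizeof(double) * (2 * size_memory + 1) * stride; // 168 MB
    size_t ram_approx   = gstate_bytes + hist_bytes;                       // ~192 MB, printed as RAM-approx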
@@ -77,11 +77,15 @@ struct SparseMat {
 feat_dim += 1;
 utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
 "feature dimension exceed limit of index_t"\
 "consider change the index_t to unsigned long");
 }
 inline size_t NumRow(void) const {
 return row_ptr.size() - 1;
 }
+// memory cost
+inline size_t MemCost(void) const {
+return data.size() * sizeof(Entry);
+}
 // maximum feature dimension
 size_t feat_dim;
 std::vector<size_t> row_ptr;
@@ -26,6 +26,9 @@ AllreduceBase::AllreduceBase(void) {
 world_size = -1;
 hadoop_mode = 0;
 version_number = 0;
+// 32 K items
+reduce_ring_mincount = 32 << 10;
+// tracker URL
 task_id = "NULL";
 err_link = NULL;
 this->SetParam("rabit_reduce_buffer", "256MB");
@@ -33,7 +36,8 @@ AllreduceBase::AllreduceBase(void) {
 env_vars.push_back("rabit_task_id");
 env_vars.push_back("rabit_num_trial");
 env_vars.push_back("rabit_reduce_buffer");
-env_vars.push_back("rabit_tracker_uri");
+env_vars.push_back("rabit_reduce_ring_mincount");
+env_vars.push_back("rabit_tracker_uri");
 env_vars.push_back("rabit_tracker_port");
 }

@@ -116,6 +120,27 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
 tracker.SendStr(msg);
 tracker.Close();
 }
+// util to parse data with unit suffix
+inline size_t ParseUnit(const char *name, const char *val) {
+char unit;
+uint64_t amount;
+int n = sscanf(val, "%lu%c", &amount, &unit);
+if (n == 2) {
+switch (unit) {
+case 'B': return amount;
+case 'K': return amount << 10UL;
+case 'M': return amount << 20UL;
+case 'G': return amount << 30UL;
+default: utils::Error("invalid format for %s", name); return 0;
+}
+} else if (n == 1) {
+return amount;
+} else {
+utils::Error("invalid format for %s," \
+"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
+return 0;
+}
+}
 /*!
 * \brief set parameters to the engine
 * \param name parameter name
@@ -127,21 +152,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
 if (!strcmp(name, "rabit_task_id")) task_id = val;
 if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
 if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
+if (!strcmp(name, "rabit_reduce_ring_mincount")) {
+reduce_ring_mincount = ParseUnit(name, val);
+}
 if (!strcmp(name, "rabit_reduce_buffer")) {
-char unit;
+reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
-uint64_t amount;
-if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
-switch (unit) {
-case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
-case 'K': reduce_buffer_size = amount << 7UL; break;
-case 'M': reduce_buffer_size = amount << 17UL; break;
-case 'G': reduce_buffer_size = amount << 27UL; break;
-default: utils::Error("invalid format for reduce buffer");
-}
-} else {
-utils::Error("invalid format for reduce_buffer,"\
-"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
-}
 }
 }
 /*!
@@ -341,6 +356,28 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
 size_t type_nbytes,
 size_t count,
 ReduceFunction reducer) {
+if (count > reduce_ring_mincount) {
+return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
+} else {
+return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
+}
+}
+/*!
+* \brief perform in-place allreduce, on sendrecvbuf,
+* this function implements tree-shape reduction
+*
+* \param sendrecvbuf_ buffer for both sending and recving data
+* \param type_nbytes the unit number of bytes the type have
+* \param count number of elements to be reduced
+* \param reducer reduce function
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType
+*/
+AllreduceBase::ReturnType
+AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
+size_t type_nbytes,
+size_t count,
+ReduceFunction reducer) {
 RefLinkVector &links = tree_links;
 if (links.size() == 0 || count == 0) return kSuccess;
 // total size of message
@@ -411,7 +448,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
 // read data from childs
 for (int i = 0; i < nlink; ++i) {
 if (i != parent_index && selecter.CheckRead(links[i].sock)) {
-ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
+ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
 if (ret != kSuccess) {
 return ReportError(&links[i], ret);
 }
@@ -599,5 +636,217 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
 }
 return kSuccess;
 }
+/*!
+* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
+* the data provided by current node k is [slice_begin, slice_end),
+* the next node's segment must start with slice_end
+* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
+* use a ring based algorithm
+*
+* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
+* \param total_size total size of data to be gathered
+* \param slice_begin beginning of the current slice
+* \param slice_end end of the current slice
+* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
+*/
+AllreduceBase::ReturnType
+AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
+size_t slice_begin,
+size_t slice_end,
+size_t size_prev_slice) {
+// read from next link and send to prev one
+LinkRecord &prev = *ring_prev, &next = *ring_next;
+// need to reply on special rank structure
+utils::Assert(next.rank == (rank + 1) % world_size &&
+rank == (prev.rank + 1) % world_size,
+"need to assume rank structure");
+// send recv buffer
+char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
+const size_t stop_read = total_size + slice_begin;
+const size_t stop_write = total_size + slice_begin - size_prev_slice;
+size_t write_ptr = slice_begin;
+size_t read_ptr = slice_end;
+
+while (true) {
+// select helper
+bool finished = true;
+utils::SelectHelper selecter;
+if (read_ptr != stop_read) {
+selecter.WatchRead(next.sock);
+finished = false;
+}
+if (write_ptr != stop_write) {
+if (write_ptr < read_ptr) {
+selecter.WatchWrite(prev.sock);
+}
+finished = false;
+}
+if (finished) break;
+selecter.Select();
+if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
+size_t size = stop_read - read_ptr;
+size_t start = read_ptr % total_size;
+if (start + size > total_size) {
+size = total_size - start;
+}
+ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
+if (len != -1) {
+read_ptr += static_cast<size_t>(len);
+} else {
+ReturnType ret = Errno2Return(errno);
+if (ret != kSuccess) return ReportError(&next, ret);
+}
+}
+if (write_ptr < read_ptr && write_ptr != stop_write) {
+size_t size = std::min(read_ptr, stop_write) - write_ptr;
+size_t start = write_ptr % total_size;
+if (start + size > total_size) {
+size = total_size - start;
+}
+ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
+if (len != -1) {
+write_ptr += static_cast<size_t>(len);
+} else {
+ReturnType ret = Errno2Return(errno);
+if (ret != kSuccess) return ReportError(&prev, ret);
+}
+}
+}
+return kSuccess;
+}
+/*!
+* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
+* and will return the cause of failure
+*
+* Ring-based algorithm
+*
+* \param sendrecvbuf_ buffer for both sending and recving data
+* \param type_nbytes the unit number of bytes the type have
+* \param count number of elements to be reduced
+* \param reducer reduce function
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType, TryAllreduce
+*/
+AllreduceBase::ReturnType
+AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
+size_t type_nbytes,
+size_t count,
+ReduceFunction reducer) {
+// read from next link and send to prev one
+LinkRecord &prev = *ring_prev, &next = *ring_next;
+// need to reply on special rank structure
+utils::Assert(next.rank == (rank + 1) % world_size &&
+rank == (prev.rank + 1) % world_size,
+"need to assume rank structure");
+// total size of message
+const size_t total_size = type_nbytes * count;
+size_t n = static_cast<size_t>(world_size);
+size_t step = (count + n - 1) / n;
+size_t r = static_cast<size_t>(next.rank);
+size_t write_ptr = std::min(r * step, count) * type_nbytes;
+size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
+size_t reduce_ptr = read_ptr;
+// send recv buffer
+char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
+// position to stop reading
+const size_t stop_read = total_size + write_ptr;
+// position to stop writing
+size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
+if (stop_write > stop_read) {
+stop_write -= total_size;
+utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
+}
+// use ring buffer in next position
+next.InitBuffer(type_nbytes, step, reduce_buffer_size);
+// set size_read to read pointer for ring buffer to work properly
+next.size_read = read_ptr;
+
+while (true) {
+// select helper
+bool finished = true;
+utils::SelectHelper selecter;
+if (read_ptr != stop_read) {
+selecter.WatchRead(next.sock);
+finished = false;
+}
+if (write_ptr != stop_write) {
+if (write_ptr < reduce_ptr) {
+selecter.WatchWrite(prev.sock);
+}
+finished = false;
+}
+if (finished) break;
+selecter.Select();
+if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
+ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
+if (ret != kSuccess) {
+return ReportError(&next, ret);
+}
+// sync the rate
+read_ptr = next.size_read;
+utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
+const size_t buffer_size = next.buffer_size;
+size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
+while (reduce_ptr < max_reduce) {
+size_t bstart = reduce_ptr % buffer_size;
+size_t nread = std::min(buffer_size - bstart,
+max_reduce - reduce_ptr);
+size_t rstart = reduce_ptr % total_size;
+nread = std::min(nread, total_size - rstart);
+reducer(next.buffer_head + bstart,
+sendrecvbuf + rstart,
+static_cast<int>(nread / type_nbytes),
+MPI::Datatype(type_nbytes));
+reduce_ptr += nread;
+}
+}
+if (write_ptr < reduce_ptr && write_ptr != stop_write) {
+size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
+size_t start = write_ptr % total_size;
+if (start + size > total_size) {
+size = total_size - start;
+}
+ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
+if (len != -1) {
+write_ptr += static_cast<size_t>(len);
+} else {
+ReturnType ret = Errno2Return(errno);
+if (ret != kSuccess) return ReportError(&prev, ret);
+}
+}
+}
+return kSuccess;
+}
+/*!
+* \brief perform in-place allreduce, on sendrecvbuf
+* use a ring based algorithm
+*
+* \param sendrecvbuf_ buffer for both sending and recving data
+* \param type_nbytes the unit number of bytes the type have
+* \param count number of elements to be reduced
+* \param reducer reduce function
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType
+*/
+AllreduceBase::ReturnType
+AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
+size_t type_nbytes,
+size_t count,
+ReduceFunction reducer) {
+ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
+if (ret != kSuccess) return ret;
+size_t n = static_cast<size_t>(world_size);
+size_t step = (count + n - 1) / n;
+size_t begin = std::min(rank * step, count) * type_nbytes;
+size_t end = std::min((rank + 1) * step, count) * type_nbytes;
+// previous rank
+int prank = ring_prev->rank;
+// get rank of previous
+return TryAllgatherRing
+(sendrecvbuf_, type_nbytes * count,
+begin, end,
+(std::min((prank + 1) * step, count) -
+std::min(prank * step, count)) * type_nbytes);
+}
 } // namespace engine
 } // namespace rabit
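Two knobs are introduced above: ParseUnit accepts a bare integer or an integer with a B/K/M/G suffix, and SetParam now routes both rabit_reduce_ring_mincount and rabit_reduce_buffer through it, with the buffer size further converted from bytes into 8-byte units by (x + 7) >> 3. TryAllreduce then picks the ring implementation whenever count exceeds reduce_ring_mincount. A small sketch of the expected values (illustrative, assuming the ParseUnit shown above):

    size_t ring_min  = ParseUnit("rabit_reduce_ring_mincount", "32K");  // 32 << 10 = 32768 items
    size_t buf_bytes = ParseUnit("rabit_reduce_buffer", "256M");        // 256 << 20 bytes
    size_t buf_words = (buf_bytes + 7) >> 3;                            // 8-byte units, as stored in reduce_buffer_size
    // with count = 1 << 20 elements and ring_min = 32768, TryAllreduce takes the ring path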
@@ -278,15 +278,19 @@ class AllreduceBase : public IEngine {
 * \brief read data into ring-buffer, with care not to existing useful override data
 * position after protect_start
 * \param protect_start all data start from protect_start is still needed in buffer
 * read shall not override this
+* \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
 * \return the type of reading
 */
-inline ReturnType ReadToRingBuffer(size_t protect_start) {
+inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
 utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
+utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
 size_t ngap = size_read - protect_start;
 utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
 size_t offset = size_read % buffer_size;
-size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
+size_t nmax = max_size_read - size_read;
+nmax = std::min(nmax, buffer_size - ngap);
+nmax = std::min(nmax, buffer_size - offset);
 if (nmax == 0) return kSuccess;
 ssize_t len = sock.Recv(buffer_head + offset, nmax);
 // length equals 0, remote disconnected
@@ -380,13 +384,79 @@ class AllreduceBase : public IEngine {
 ReduceFunction reducer);
 /*!
 * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
-* \param sendrecvbuf_ buffer for both sending and recving data
+* \param sendrecvbuf_ buffer for both sending and receiving data
 * \param size the size of the data to be broadcasted
 * \param root the root worker id to broadcast the data
 * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
 * \sa ReturnType
 */
 ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
+/*!
+* \brief perform in-place allreduce, on sendrecvbuf,
+* this function implements tree-shape reduction
+*
+* \param sendrecvbuf_ buffer for both sending and recving data
+* \param type_nbytes the unit number of bytes the type have
+* \param count number of elements to be reduced
+* \param reducer reduce function
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType
+*/
+ReturnType TryAllreduceTree(void *sendrecvbuf_,
+size_t type_nbytes,
+size_t count,
+ReduceFunction reducer);
+/*!
+* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
+* the data provided by current node k is [slice_begin, slice_end),
+* the next node's segment must start with slice_end
+* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
+* use a ring based algorithm
+*
+* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
+* \param total_size total size of data to be gathered
+* \param slice_begin beginning of the current slice
+* \param slice_end end of the current slice
+* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType
+*/
+ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
+size_t slice_begin, size_t slice_end,
+size_t size_prev_slice);
+/*!
+* \brief perform in-place allreduce, reduce on the sendrecvbuf,
+*
+* after the function, node k get k-th segment of the reduction result
+* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
+* where step = ceil(count / world_size)
+*
+* \param sendrecvbuf_ buffer for both sending and recving data
+* \param type_nbytes the unit number of bytes the type have
+* \param count number of elements to be reduced
+* \param reducer reduce function
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType, TryAllreduce
+*/
+ReturnType TryReduceScatterRing(void *sendrecvbuf_,
+size_t type_nbytes,
+size_t count,
+ReduceFunction reducer);
+/*!
+* \brief perform in-place allreduce, on sendrecvbuf
+* use a ring based algorithm, reduce-scatter + allgather
+*
+* \param sendrecvbuf_ buffer for both sending and recving data
+* \param type_nbytes the unit number of bytes the type have
+* \param count number of elements to be reduced
+* \param reducer reduce function
+* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
+* \sa ReturnType
+*/
+ReturnType TryAllreduceRing(void *sendrecvbuf_,
+size_t type_nbytes,
+size_t count,
+ReduceFunction reducer);
 /*!
 * \brief function used to report error when a link goes wrong
 * \param link the pointer to the link who causes the error
@@ -432,6 +502,10 @@ class AllreduceBase : public IEngine {
 int slave_port, nport_trial;
 // reduce buffer size
 size_t reduce_buffer_size;
+// reduction method
+int reduce_method;
+// mininum count of cells to use ring based method
+size_t reduce_ring_mincount;
 // current rank
 int rank;
 // world size
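The comment block for TryReduceScatterRing pins down the data layout: count elements are cut into world_size segments of step = ceil(count / world_size), and after reduce-scatter node k holds the reduced segment [k * step, min((k + 1) * step, count)), which the allgather phase then circulates around the ring. A worked sketch of the segment boundaries (illustrative values only):

    #include <algorithm>
    // segment layout for count = 10 elements across world_size = 4 nodes
    const size_t count = 10, world_size = 4;
    const size_t step = (count + world_size - 1) / world_size;  // ceil(10 / 4) = 3
    for (size_t k = 0; k < world_size; ++k) {
      size_t begin = std::min(k * step, count);        // 0, 3, 6, 9
      size_t end   = std::min((k + 1) * step, count);  // 3, 6, 9, 10
      // node k owns [begin, end); node 3's segment is shorter (1 element)
    }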
@@ -81,18 +81,18 @@ class AllreduceMock : public AllreduceRobust {
 ComboSerializer com(global_model, local_model);
 AllreduceRobust::CheckPoint(&dum, &com);
 }
-tsum_allreduce = 0.0;
 time_checkpoint = utils::GetTime();
 double tcost = utils::GetTime() - tstart;
 if (report_stats != 0 && rank == 0) {
 std::stringstream ss;
 ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
-<< "local_size=" << local_chkpt[local_chkpt_version].length()
+<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
-<< "check_tcost="<< tcost <<" sec,"
+<< ",check_tcost="<< tcost <<" sec"
-<< "allreduce_tcost=" << tsum_allreduce << " sec,"
+<< ",allreduce_tcost=" << tsum_allreduce << " sec"
-<< "between_chpt=" << tbet_chkpt << "sec\n";
+<< ",between_chpt=" << tbet_chkpt << "sec\n";
 this->TrackerPrint(ss.str());
 }
+tsum_allreduce = 0.0;
 }

 virtual void LazyCheckPoint(const ISerializable *global_model) {
@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
         if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
       }
       utils::Assert(min_write <= links[pid].size_read, "boundary check");
-      ReturnType ret = links[pid].ReadToRingBuffer(min_write);
+      ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
       if (ret != kSuccess) {
         return ReportError(&links[pid], ret);
       }
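Passing the total size into ReadToRingBuffer lets the receiver cap how far it reads ahead of the slowest pending writer while also stopping at the end of the payload. A rough sketch of that bounded-read idea (illustrative only; the names below are not taken from the rabit sources):

    def read_to_ring_buffer(sock, buf, capacity, write_ofs, consumed, total_size):
        # buf is a bytearray of length `capacity`; `consumed` is the slowest
        # consumer's offset, so we never overwrite bytes it has not used yet
        max_ahead = capacity - (write_ofs - consumed)   # free space in the ring
        want = min(max_ahead, total_size - write_ofs)   # also respect end of data
        if want <= 0:
            return write_ofs
        data = sock.recv(want)
        for i, b in enumerate(data):
            buf[(write_ofs + i) % capacity] = b
        return write_ofs + len(data)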
@ -287,7 +287,6 @@ class AllreduceRobust : public AllreduceBase {
     if (seqno_.size() == 0) return -1;
     return seqno_.back();
   }
-
  private:
   // sequence number of each
   std::vector<int> seqno_;
@ -1,7 +1,7 @@
 export CC = gcc
 export CXX = g++
 export MPICXX = mpicxx
-export LDFLAGS= -pthread -lm -lrt -L../lib
+export LDFLAGS= -L../lib -pthread -lm -lrt
 export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11

 # specify tensor path
@ -29,7 +29,7 @@ local_recover: local_recover.o $(RABIT_OBJ)
 lazy_recover: lazy_recover.o $(RABIT_OBJ)

 $(BIN) :
-	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
+	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)

 $(OBJ) :
	$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
@ -23,4 +23,7 @@ lazy_recover_10_10k_die_hard:
	../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0

 lazy_recover_10_10k_die_same:
	../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
+
+ringallreduce_10_10k:
+	../tracker/rabit_demo.py -v 1 -n 10 model_recover 100 rabit_reduce_ring_mincount=10
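The new ringallreduce_10_10k target forces the ring code path by lowering rabit_reduce_ring_mincount to 10; the parameter acts as a size threshold below which the tree-based reducer is kept. A toy sketch of that kind of dispatch (the function names are placeholders, only meant to show the role of the threshold):

    def allreduce(buf, count, reduce_ring_mincount, tree_allreduce, ring_allreduce):
        if count >= reduce_ring_mincount:
            return ring_allreduce(buf, count)   # large payloads: bandwidth-optimal ring
        return tree_allreduce(buf, count)       # small payloads: low-latency tree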
@ -9,4 +9,4 @@ the example guidelines are in the script themselfs
 * Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
   - It is also possible to submit via hadoop streaming with rabit_hadoop_streaming.py
   - However, it is higly recommended to use rabit_yarn.py because this will allocate resources more precisely and fits machine learning scenarios
+* Sun Grid engine: [rabit_sge.py](rabit_sge.py)
@ -1,7 +1,6 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
-This is the demo submission script of rabit, it is created to
-submit rabit jobs using hadoop streaming
+This is the demo submission script of rabit for submitting jobs in local machine
 """
 import argparse
 import sys
@ -43,7 +42,7 @@ def exec_cmd(cmd, taskid, worker_env):
     if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
         cmd[0] = './' + cmd[0]
     cmd = ' '.join(cmd)
-    env = {}
+    env = os.environ.copy()
     for k, v in worker_env.items():
         env[k] = str(v)
     env['rabit_task_id'] = str(taskid)
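Starting the worker from a copy of os.environ, instead of an empty dict, keeps PATH, LD_LIBRARY_PATH and similar settings visible to the child while still overlaying the tracker-provided rabit variables. A small sketch of that launch pattern (the worker command string is a placeholder, not the demo script's exact code):

    import os, subprocess

    def launch_worker(cmd, taskid, worker_env):
        # start from the parent environment so PATH/LD_LIBRARY_PATH survive
        env = os.environ.copy()
        # overlay tracker-provided settings (values must be strings)
        for k, v in worker_env.items():
            env[k] = str(v)
        env['rabit_task_id'] = str(taskid)
        return subprocess.Popen(cmd, shell=True, env=env)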
@ -1,4 +1,4 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
 Deprecated

@ -1,7 +1,6 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
-This is the demo submission script of rabit, it is created to
-submit rabit jobs using hadoop streaming
+Submission script to submit rabit jobs using MPI
 """
 import argparse
 import sys
subtree/rabit/tracker/rabit_sge.py (new executable file)
@ -0,0 +1,69 @@
+#!/usr/bin/env python
+"""
+Submit rabit jobs to Sun Grid Engine
+"""
+import argparse
+import sys
+import os
+import subprocess
+import rabit_tracker as tracker
+
+parser = argparse.ArgumentParser(description='Rabit script to submit rabit job using MPI')
+parser.add_argument('-n', '--nworker', required=True, type=int,
+                    help = 'number of worker proccess to be launched')
+parser.add_argument('-q', '--queue', default='default', type=str,
+                    help = 'the queue we want to submit the job to')
+parser.add_argument('-hip', '--host_ip', default='auto', type=str,
+                    help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
+parser.add_argument('--vcores', default = 1, type=int,
+                    help = 'number of vcpores to request in each mapper, set it if each rabit job is multi-threaded')
+parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
+parser.add_argument('--logdir', default='auto', help = 'customize the directory to place the logs')
+parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
+                    help = 'print more messages into the console')
+parser.add_argument('command', nargs='+',
+                    help = 'command for rabit program')
+args = parser.parse_args()
+
+if args.jobname == 'auto':
+    args.jobname = ('rabit%d.' % args.nworker) + args.command[0].split('/')[-1];
+if args.logdir == 'auto':
+    args.logdir = args.jobname + '.log'
+
+if os.path.exists(args.logdir):
+    if not os.path.isdir(args.logdir):
+        raise RuntimeError('specified logdir %s is a file instead of directory' % args.logdir)
+else:
+    os.mkdir(args.logdir)
+
+runscript = '%s/runrabit.sh' % args.logdir
+fo = open(runscript, 'w')
+fo.write('\"$@\"\n')
+fo.close()
+#
+# submission script using MPI
+#
+def sge_submit(nslave, worker_args, worker_envs):
+    """
+    customized submit script, that submit nslave jobs, each must contain args as parameter
+    note this can be a lambda function containing additional parameters in input
+    Parameters
+       nslave number of slave process to start up
+       args arguments to launch each job
+            this usually includes the parameters of master_uri and parameters passed into submit
+    """
+    env_arg = ','.join(['%s=\"%s\"' % (k, str(v)) for k, v in worker_envs.items()])
+    cmd = 'qsub -cwd -t 1-%d -S /bin/bash' % nslave
+    if args.queue != 'default':
+        cmd += '-q %s' % args.queue
+    cmd += ' -N %s ' % args.jobname
+    cmd += ' -e %s -o %s' % (args.logdir, args.logdir)
+    cmd += ' -pe orte %d' % (args.vcores)
+    cmd += ' -v %s,PATH=${PATH}:.' % env_arg
+    cmd += ' %s %s' % (runscript, ' '.join(args.command + worker_args))
+    print cmd
+    subprocess.check_call(cmd, shell = True)
+    print 'Waiting for the jobs to get up...'
+
+# call submit, with nslave, the commands to run each job and submit function
+tracker.submit(args.nworker, [], fun_submit = sge_submit, verbose = args.verbose)
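The new script follows the same pattern as the other launchers: it only implements a submit callback and hands it to rabit_tracker.submit, which runs the tracker and passes the worker environment into the callback. A minimal sketch of a custom launcher built on the same hook (the my_submit function and the ./my_worker command are illustrative, not part of the repository):

    import subprocess
    import rabit_tracker as tracker

    def my_submit(nslave, worker_args, worker_envs):
        # launch nslave local processes; a real launcher would go through
        # qsub, yarn or mpirun instead of plain subprocess
        for i in range(nslave):
            env_args = ['%s=%s' % (k, str(v)) for k, v in worker_envs.items()]
            cmd = ['env'] + env_args + ['./my_worker'] + worker_args
            subprocess.Popen(cmd)

    tracker.submit(4, [], fun_submit = my_submit, verbose = 0)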
@ -13,6 +13,7 @@ import socket
 import struct
 import subprocess
 import random
+import time
 from threading import Thread

 """
@ -188,6 +189,7 @@ class Tracker:
             vlst.reverse()
             rlst += vlst
         return rlst
+
     def get_ring(self, tree_map, parent_map):
         """
         get a ring connection used to recover local data
@ -202,14 +204,44 @@ class Tracker:
             rnext = (r + 1) % nslave
             ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
         return ring_map
+
+    def get_link_map(self, nslave):
+        """
+        get the link map, this is a bit hacky, call for better algorithm
+        to place similar nodes together
+        """
+        tree_map, parent_map = self.get_tree(nslave)
+        ring_map = self.get_ring(tree_map, parent_map)
+        rmap = {0 : 0}
+        k = 0
+        for i in range(nslave - 1):
+            k = ring_map[k][1]
+            rmap[k] = i + 1
+
+        ring_map_ = {}
+        tree_map_ = {}
+        parent_map_ ={}
+        for k, v in ring_map.items():
+            ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
+        for k, v in tree_map.items():
+            tree_map_[rmap[k]] = [rmap[x] for x in v]
+        for k, v in parent_map.items():
+            if k != 0:
+                parent_map_[rmap[k]] = rmap[v]
+            else:
+                parent_map_[rmap[k]] = -1
+        return tree_map_, parent_map_, ring_map_
+
     def handle_print(self,slave, msg):
         sys.stdout.write(msg)

     def log_print(self, msg, level):
         if level == 1:
             if self.verbose:
                 sys.stderr.write(msg + '\n')
         else:
             sys.stderr.write(msg + '\n')

     def accept_slaves(self, nslave):
         # set of nodes that finishs the job
         shutdown = {}
@ -241,31 +273,40 @@ class Tracker:
                 assert s.cmd == 'start'
                 if s.world_size > 0:
                     nslave = s.world_size
-                    tree_map, parent_map = self.get_tree(nslave)
-                    ring_map = self.get_ring(tree_map, parent_map)
+                    tree_map, parent_map, ring_map = self.get_link_map(nslave)
                     # set of nodes that is pending for getting up
                     todo_nodes = range(nslave)
-                    random.shuffle(todo_nodes)
             else:
                 assert s.world_size == -1 or s.world_size == nslave
                 if s.cmd == 'recover':
                     assert s.rank >= 0

             rank = s.decide_rank(job_map)
+            # batch assignment of ranks
             if rank == -1:
                 assert len(todo_nodes) != 0
-                rank = todo_nodes.pop(0)
-                if s.jobid != 'NULL':
-                    job_map[s.jobid] = rank
+                pending.append(s)
+                if len(pending) == len(todo_nodes):
+                    pending.sort(key = lambda x : x.host)
+                    for s in pending:
+                        rank = todo_nodes.pop(0)
+                        if s.jobid != 'NULL':
+                            job_map[s.jobid] = rank
+                        s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
+                        if s.wait_accept > 0:
+                            wait_conn[rank] = s
+                        self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
                 if len(todo_nodes) == 0:
                     self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
-                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-                if s.cmd != 'start':
-                    self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
+                    self.start_time = time.time()
             else:
-                self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
-                if s.wait_accept > 0:
-                    wait_conn[rank] = s
+                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
+                self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
+                if s.wait_accept > 0:
+                    wait_conn[rank] = s
         self.log_print('@tracker All nodes finishes job', 2)
+        self.end_time = time.time()
+        self.log_print('@tracker %s secs between node start and job finish' % str(self.end_time - self.start_time), 2)

 def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
     master = Tracker(verbose = verbose, hostIP = hostIP)
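get_link_map renumbers the nodes so that following the ring's next pointers from rank 0 yields ranks 0, 1, 2, ... in order, then rewrites the tree, parent and ring maps through the same permutation. A toy illustration of that relabeling step on a made-up four-node ring:

    # node -> (prev, next); the ring here is 0 -> 1 -> 3 -> 2 -> 0
    ring_map = {0: (2, 1), 1: (0, 3), 3: (1, 2), 2: (3, 0)}

    rmap, k = {0: 0}, 0
    for i in range(len(ring_map) - 1):
        k = ring_map[k][1]          # follow the 'next' pointer around the ring
        rmap[k] = i + 1             # the i-th hop from node 0 becomes rank i+1

    relabeled = {rmap[n]: (rmap[p], rmap[nx]) for n, (p, nx) in ring_map.items()}
    print(rmap)       # {0: 0, 1: 1, 3: 2, 2: 3}
    print(relabeled)  # {0: (3, 1), 1: (0, 2), 2: (1, 3), 3: (2, 0)}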
@ -1,4 +1,4 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 """
 This is a script to submit rabit job via Yarn
 rabit will run as a Yarn application
@ -13,6 +13,7 @@ import rabit_tracker as tracker

 WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
 YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
+YARN_BOOT_PY = os.path.dirname(__file__) + '/../yarn/run_hdfs_prog.py'

 if not os.path.exists(YARN_JAR_PATH):
     warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
@ -21,7 +22,7 @@ if not os.path.exists(YARN_JAR_PATH):
     subprocess.check_call(cmd, shell = True, env = os.environ)
     assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"

-hadoop_binary = 'hadoop'
+hadoop_binary = None
 # code
 hadoop_home = os.getenv('HADOOP_HOME')

@ -38,6 +39,8 @@ parser.add_argument('-hip', '--host_ip', default='auto', type=str,
                    help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
 parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
                    help = 'print more messages into the console')
+parser.add_argument('-q', '--queue', default='default', type=str,
+                    help = 'the queue we want to submit the job to')
 parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
                    help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
 parser.add_argument('-f', '--files', default = [], action='append',
@ -56,6 +59,11 @@ parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
                    help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
                    'if you are running multi-threading rabit,'\
                    'so that each node can occupy all the mapper slots in a machine for maximum performance')
+parser.add_argument('--libhdfs-opts', default='-Xmx128m', type=str,
+                    help = 'setting to be passed to libhdfs')
+parser.add_argument('--name-node', default='default', type=str,
+                    help = 'the namenode address of hdfs, libhdfs should connect to, normally leave it as default')
+
 parser.add_argument('command', nargs='+',
                    help = 'command for rabit program')
 args = parser.parse_args()
@ -87,7 +95,7 @@ if hadoop_version < 2:
    print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]

 def submit_yarn(nworker, worker_args, worker_env):
-    fset = set([YARN_JAR_PATH])
+    fset = set([YARN_JAR_PATH, YARN_BOOT_PY])
     if args.auto_file_cache != 0:
         for i in range(len(args.command)):
             f = args.command[i]
@ -96,7 +104,7 @@ def submit_yarn(nworker, worker_args, worker_env):
             if i == 0:
                 args.command[i] = './' + args.command[i].split('/')[-1]
             else:
-                args.command[i] = args.command[i].split('/')[-1]
+                args.command[i] = './' + args.command[i].split('/')[-1]
     if args.command[0].endswith('.py'):
         flst = [WRAPPER_PATH + '/rabit.py',
                 WRAPPER_PATH + '/librabit_wrapper.so',
@ -112,6 +120,8 @@ def submit_yarn(nworker, worker_args, worker_env):
     env['rabit_cpu_vcores'] = str(args.vcores)
     env['rabit_memory_mb'] = str(args.memory_mb)
     env['rabit_world_size'] = str(args.nworker)
+    env['rabit_hdfs_opts'] = str(args.libhdfs_opts)
+    env['rabit_hdfs_namenode'] = str(args.name_node)

     if args.files != None:
         for flst in args.files:
@ -121,7 +131,8 @@ def submit_yarn(nworker, worker_args, worker_env):
         cmd += ' -file %s' % f
     cmd += ' -jobname %s ' % args.jobname
     cmd += ' -tempdir %s ' % args.tempdir
-    cmd += (' '.join(args.command + worker_args))
+    cmd += ' -queue %s ' % args.queue
+    cmd += (' '.join(['./run_hdfs_prog.py'] + args.command + worker_args))
     if args.verbose != 0:
         print cmd
     subprocess.check_call(cmd, shell = True, env = env)
@ -1 +0,0 @@
-foler used to hold generated class files
@ -1,8 +1,8 @@
 #!/bin/bash
-if [ -z "$HADOOP_PREFIX" ]; then
-    echo "cannot found $HADOOP_PREFIX in the environment variable, please set it properly"
-    exit 1
+if [ ! -d bin ]; then
+    mkdir bin
 fi
-CPATH=`${HADOOP_PREFIX}/bin/hadoop classpath`
+
+CPATH=`${HADOOP_HOME}/bin/hadoop classpath`
 javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
 jar cf rabit-yarn.jar -C bin .
subtree/rabit/yarn/run_hdfs_prog.py (new executable file)
@ -0,0 +1,45 @@
+#!/usr/bin/env python
+"""
+this script helps setup classpath env for HDFS, before running program
+that links with libhdfs
+"""
+import glob
+import sys
+import os
+import subprocess
+
+if len(sys.argv) < 2:
+    print 'Usage: the command you want to run'
+
+hadoop_home = os.getenv('HADOOP_HOME')
+hdfs_home = os.getenv('HADOOP_HDFS_HOME')
+java_home = os.getenv('JAVA_HOME')
+if hadoop_home is None:
+    hadoop_home = os.getenv('HADOOP_PREFIX')
+assert hadoop_home is not None, 'need to set HADOOP_HOME'
+assert hdfs_home is not None, 'need to set HADOOP_HDFS_HOME'
+assert java_home is not None, 'need to set JAVA_HOME'
+
+(classpath, err) = subprocess.Popen('%s/bin/hadoop classpath' % hadoop_home,
+                                    stdout=subprocess.PIPE, shell = True,
+                                    env = os.environ).communicate()
+cpath = []
+for f in classpath.split(':'):
+    cpath += glob.glob(f)
+
+lpath = []
+lpath.append('%s/lib/native' % hdfs_home)
+lpath.append('%s/jre/lib/amd64/server' % java_home)
+
+env = os.environ.copy()
+env['CLASSPATH'] = '${CLASSPATH}:' + (':'.join(cpath))
+
+# setup hdfs options
+if 'rabit_hdfs_opts' in env:
+    env['LIBHDFS_OPTS'] = env['rabit_hdfs_opts']
+elif 'LIBHDFS_OPTS' not in env:
+    env['LIBHDFS_OPTS'] = '--Xmx128m'
+
+env['LD_LIBRARY_PATH'] = '${LD_LIBRARY_PATH}:' + (':'.join(lpath))
+ret = subprocess.call(args = sys.argv[1:], env = env)
+sys.exit(ret)
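The wrapper is meant to sit in front of any libhdfs-linked worker: rabit_yarn.py ships it to each container, the launch command runs the user program through it, and the rabit_hdfs_opts value set by the tracker ends up as LIBHDFS_OPTS. A hedged usage sketch from Python, where ./my_worker is a placeholder binary and HADOOP_HOME, HADOOP_HDFS_HOME and JAVA_HOME are assumed to be exported already:

    import os, subprocess

    env = os.environ.copy()
    env['rabit_hdfs_opts'] = '-Xmx128m'   # forwarded into LIBHDFS_OPTS by the wrapper
    subprocess.check_call(['./run_hdfs_prog.py', './my_worker', 'arg1'], env=env)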
@ -1,5 +1,6 @@
 package org.apache.hadoop.yarn.rabit;

+import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
@ -7,6 +8,7 @@ import java.util.Map;
 import java.util.Queue;
 import java.util.Collection;
 import java.util.Collections;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
@ -34,6 +36,7 @@ import org.apache.hadoop.yarn.api.records.NodeReport;
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
 import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
+import org.apache.hadoop.security.UserGroupInformation;

 /**
  * application master for allocating resources of rabit client
@ -61,6 +64,8 @@ public class ApplicationMaster {
     // command to launch
     private String command = "";

+    // username
+    private String userName = "";
     // application tracker hostname
     private String appHostName = "";
     // tracker URL to do
@ -128,6 +133,8 @@ public class ApplicationMaster {
      */
     private void initArgs(String args[]) throws IOException {
         LOG.info("Invoke initArgs");
+        // get user name
+        userName = UserGroupInformation.getCurrentUser().getShortUserName();
         // cached maps
         Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
         for (int i = 0; i < args.length; ++i) {
@ -156,7 +163,8 @@ public class ApplicationMaster {
         numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
         numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
         numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
-        maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false, maxNumAttempt);
+        maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false,
+                                           maxNumAttempt);
     }

     /**
@ -175,7 +183,7 @@ public class ApplicationMaster {
         RegisterApplicationMasterResponse response = this.rmClient
                 .registerApplicationMaster(this.appHostName,
                         this.appTrackerPort, this.appTrackerUrl);

         boolean success = false;
         String diagnostics = "";
         try {
@ -208,19 +216,20 @@ public class ApplicationMaster {
             assert (killedTasks.size() + finishedTasks.size() == numTasks);
             success = finishedTasks.size() == numTasks;
             LOG.info("Application completed. Stopping running containers");
-            nmClient.stop();
             diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
                     + ", finished=" + this.finishedTasks.size() + ", failed="
                     + this.killedTasks.size() + "\n" + this.abortDiagnosis;
+            nmClient.stop();
             LOG.info(diagnostics);
         } catch (Exception e) {
             diagnostics = e.toString();
         }
         rmClient.unregisterApplicationMaster(
                 success ? FinalApplicationStatus.SUCCEEDED
                         : FinalApplicationStatus.FAILED, diagnostics,
                 appTrackerUrl);
-        if (!success) throw new Exception("Application not successful");
+        if (!success)
+            throw new Exception("Application not successful");
     }

     /**
@ -265,30 +274,63 @@ public class ApplicationMaster {
         task.containerRequest = null;
         ContainerLaunchContext ctx = Records
                 .newRecord(ContainerLaunchContext.class);
-        String cmd =
-                // use this to setup CLASSPATH correctly for libhdfs
-                "CLASSPATH=${CLASSPATH}:`${HADOOP_PREFIX}/bin/hadoop classpath --glob` "
-                + this.command + " 1>"
-                + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
-                + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
-                + "/stderr";
-        LOG.info(cmd);
+        String hadoop = "hadoop";
+        if (System.getenv("HADOOP_HOME") != null) {
+            hadoop = "${HADOOP_HOME}/bin/hadoop";
+        } else if (System.getenv("HADOOP_PREFIX") != null) {
+            hadoop = "${HADOOP_PREFIX}/bin/hadoop";
+        }
+
+        String cmd =
+                // use this to setup CLASSPATH correctly for libhdfs
+                this.command + " 1>"
+                + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+                + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+                + "/stderr";
         ctx.setCommands(Collections.singletonList(cmd));
         LOG.info(workerResources);
         ctx.setLocalResources(this.workerResources);
         // setup environment variables
         Map<String, String> env = new java.util.HashMap<String, String>();

         // setup class path, this is kind of duplicated, ignoring
         StringBuilder cpath = new StringBuilder("${CLASSPATH}:./*");
         for (String c : conf.getStrings(
                 YarnConfiguration.YARN_APPLICATION_CLASSPATH,
                 YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
-            cpath.append(':');
-            cpath.append(c.trim());
+            String[] arrPath = c.split(":");
+            for (String ps : arrPath) {
+                if (ps.endsWith("*.jar") || ps.endsWith("*")) {
+                    ps = ps.substring(0, ps.lastIndexOf('*'));
+                    String prefix = ps.substring(0, ps.lastIndexOf('/'));
+                    if (ps.startsWith("$")) {
+                        String[] arr =ps.split("/", 2);
+                        if (arr.length != 2) continue;
+                        try {
+                            ps = System.getenv(arr[0].substring(1)) + '/' + arr[1];
+                        } catch (Exception e){
+                            continue;
+                        }
+                    }
+                    File dir = new File(ps);
+                    if (dir.isDirectory()) {
+                        for (File f: dir.listFiles()) {
+                            if (f.isFile() && f.getPath().endsWith(".jar")) {
+                                cpath.append(":");
+                                cpath.append(prefix + '/' + f.getName());
+                            }
+                        }
+                    }
+                } else {
+                    cpath.append(':');
+                    cpath.append(ps.trim());
+                }
+            }
         }
-        // already use hadoop command to get class path in worker, maybe a better solution in future
-        // env.put("CLASSPATH", cpath.toString());
+        // already use hadoop command to get class path in worker, maybe a
+        // better solution in future
+        env.put("CLASSPATH", cpath.toString());
+        //LOG.info("CLASSPATH =" + cpath.toString());
         // setup LD_LIBARY_PATH path for libhdfs
         env.put("LD_LIBRARY_PATH",
                 "${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
@ -298,10 +340,13 @@ public class ApplicationMaster {
             if (e.getKey().startsWith("rabit_")) {
                 env.put(e.getKey(), e.getValue());
             }
+            if (e.getKey() == "LIBHDFS_OPTS") {
+                env.put(e.getKey(), e.getValue());
+            }
         }
         env.put("rabit_task_id", String.valueOf(task.taskId));
         env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
+        // ctx.setUser(userName);
         ctx.setEnvironment(env);
         synchronized (this) {
             assert (!this.runningTasks.containsKey(container.getId()));
@ -376,8 +421,17 @@ public class ApplicationMaster {
         Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
         for (ContainerId cid : failed) {
             TaskRecord r = runningTasks.remove(cid);
-            if (r == null)
+            if (r == null) {
                 continue;
+            }
+            LOG.info("Task "
+                    + r.taskId
+                    + "failed on "
+                    + r.container.getId()
+                    + ". See LOG at : "
+                    + String.format("http://%s/node/containerlogs/%s/"
+                            + userName, r.container.getNodeHttpAddress(),
+                            r.container.getId()));
             r.attemptCounter += 1;
             r.container = null;
             tasks.add(r);
@ -411,22 +465,26 @@ public class ApplicationMaster {
                 finishedTasks.add(r);
                 runningTasks.remove(s.getContainerId());
             } else {
-                switch (exstatus) {
-                case ContainerExitStatus.KILLED_EXCEEDED_PMEM:
-                    this.abortJob("[Rabit] Task "
-                            + r.taskId
-                            + " killed because of exceeding allocated physical memory");
-                    break;
-                case ContainerExitStatus.KILLED_EXCEEDED_VMEM:
-                    this.abortJob("[Rabit] Task "
-                            + r.taskId
-                            + " killed because of exceeding allocated virtual memory");
-                    break;
-                default:
-                    LOG.info("[Rabit] Task " + r.taskId
-                            + " exited with status " + exstatus);
-                    failed.add(s.getContainerId());
+                try {
+                    if (exstatus == ContainerExitStatus.class.getField(
+                            "KILLED_EXCEEDED_PMEM").getInt(null)) {
+                        this.abortJob("[Rabit] Task "
+                                + r.taskId
+                                + " killed because of exceeding allocated physical memory");
+                        continue;
+                    }
+                    if (exstatus == ContainerExitStatus.class.getField(
+                            "KILLED_EXCEEDED_VMEM").getInt(null)) {
+                        this.abortJob("[Rabit] Task "
+                                + r.taskId
+                                + " killed because of exceeding allocated virtual memory");
+                        continue;
+                    }
+                } catch (Exception e) {
                 }
+                LOG.info("[Rabit] Task " + r.taskId + " exited with status "
+                        + exstatus + " Diagnostics:"+ s.getDiagnostics());
+                failed.add(s.getContainerId());
             }
         }
         this.handleFailure(failed);
@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationReport;
@ -19,6 +20,7 @@ import org.apache.hadoop.yarn.api.records.LocalResource;
 import org.apache.hadoop.yarn.api.records.LocalResourceType;
 import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
 import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.QueueInfo;
 import org.apache.hadoop.yarn.api.records.YarnApplicationState;
 import org.apache.hadoop.yarn.client.api.YarnClient;
 import org.apache.hadoop.yarn.client.api.YarnClientApplication;
@ -43,14 +45,21 @@ public class Client {
     private String appArgs = "";
     // HDFS Path to store temporal result
     private String tempdir = "/tmp";
+    // user name
+    private String userName = "";
     // job name
     private String jobName = "";
+    // queue
+    private String queue = "default";
     /**
      * constructor
      * @throws IOException
      */
     private Client() throws IOException {
+        conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/core-site.xml"));
+        conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/hdfs-site.xml"));
         dfs = FileSystem.get(conf);
+        userName = UserGroupInformation.getCurrentUser().getShortUserName();
     }

     /**
@ -127,6 +136,9 @@ public class Client {
             if (e.getKey().startsWith("rabit_")) {
                 env.put(e.getKey(), e.getValue());
             }
+            if (e.getKey() == "LIBHDFS_OPTS") {
+                env.put(e.getKey(), e.getValue());
+            }
         }
         LOG.debug(env);
         return env;
@ -152,6 +164,8 @@ public class Client {
                 this.jobName = args[++i];
             } else if(args[i].equals("-tempdir")) {
                 this.tempdir = args[++i];
+            } else if(args[i].equals("-queue")) {
+                this.queue = args[++i];
             } else {
                 sargs.append(" ");
                 sargs.append(args[i]);
@ -168,7 +182,6 @@ public class Client {
         }
         this.initArgs(args);
         // Create yarnClient
-        YarnConfiguration conf = new YarnConfiguration();
         YarnClient yarnClient = YarnClient.createYarnClient();
         yarnClient.init(conf);
         yarnClient.start();
@ -181,13 +194,14 @@ public class Client {
                 .newRecord(ContainerLaunchContext.class);
         ApplicationSubmissionContext appContext = app
                 .getApplicationSubmissionContext();

         // Submit application
         ApplicationId appId = appContext.getApplicationId();
         // setup cache-files and environment variables
         amContainer.setLocalResources(this.setupCacheFiles(appId));
         amContainer.setEnvironment(this.getEnvironment());
         String cmd = "$JAVA_HOME/bin/java"
-                + " -Xmx256M"
+                + " -Xmx900M"
                 + " org.apache.hadoop.yarn.rabit.ApplicationMaster"
                 + this.cacheFileArg + ' ' + this.appArgs + " 1>"
                 + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
@ -197,15 +211,15 @@ public class Client {

         // Set up resource type requirements for ApplicationMaster
         Resource capability = Records.newRecord(Resource.class);
-        capability.setMemory(256);
+        capability.setMemory(1024);
         capability.setVirtualCores(1);
-        LOG.info("jobname=" + this.jobName);
+        LOG.info("jobname=" + this.jobName + ",username=" + this.userName);

         appContext.setApplicationName(jobName + ":RABIT-YARN");
         appContext.setAMContainerSpec(amContainer);
         appContext.setResource(capability);
-        appContext.setQueue("default");
+        appContext.setQueue(queue);
+        //appContext.setUser(userName);
         LOG.info("Submitting application " + appId);
         yarnClient.submitApplication(appContext);

@ -218,12 +232,16 @@ public class Client {
             appReport = yarnClient.getApplicationReport(appId);
             appState = appReport.getYarnApplicationState();
         }

         System.out.println("Application " + appId + " finished with"
                 + " state " + appState + " at " + appReport.getFinishTime());
         if (!appReport.getFinalApplicationStatus().equals(
                 FinalApplicationStatus.SUCCEEDED)) {
             System.err.println(appReport.getDiagnostics());
+            System.out.println("Available queues:");
+            for (QueueInfo q : yarnClient.getAllQueues()) {
+                System.out.println(q.getQueueName());
+            }
         }
     }