add ringbased passing and batch schedule
This commit is contained in:
parent
295d8a12f1
commit
c57dad8b17
@ -19,6 +19,7 @@ class FileStream : public utils::ISeekStream {
|
|||||||
public:
|
public:
|
||||||
explicit FileStream(const char *fname, const char *mode)
|
explicit FileStream(const char *fname, const char *mode)
|
||||||
: use_stdio(false) {
|
: use_stdio(false) {
|
||||||
|
using namespace std;
|
||||||
#ifndef RABIT_STRICT_CXX98_
|
#ifndef RABIT_STRICT_CXX98_
|
||||||
if (!strcmp(fname, "stdin")) {
|
if (!strcmp(fname, "stdin")) {
|
||||||
use_stdio = true; fp = stdin;
|
use_stdio = true; fp = stdin;
|
||||||
@ -51,7 +52,7 @@ class FileStream : public utils::ISeekStream {
|
|||||||
return std::ftell(fp);
|
return std::ftell(fp);
|
||||||
}
|
}
|
||||||
virtual bool AtEnd(void) const {
|
virtual bool AtEnd(void) const {
|
||||||
return feof(fp) != 0;
|
return std::feof(fp) != 0;
|
||||||
}
|
}
|
||||||
inline void Close(void) {
|
inline void Close(void) {
|
||||||
if (fp != NULL && !use_stdio) {
|
if (fp != NULL && !use_stdio) {
|
||||||
@ -60,7 +61,7 @@ class FileStream : public utils::ISeekStream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FILE *fp;
|
std::FILE *fp;
|
||||||
bool use_stdio;
|
bool use_stdio;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -71,7 +72,7 @@ class FileSplit : public LineSplitBase {
|
|||||||
LineSplitBase::SplitNames(&fnames_, uri, "#");
|
LineSplitBase::SplitNames(&fnames_, uri, "#");
|
||||||
std::vector<size_t> fsize;
|
std::vector<size_t> fsize;
|
||||||
for (size_t i = 0; i < fnames_.size(); ++i) {
|
for (size_t i = 0; i < fnames_.size(); ++i) {
|
||||||
if (!strncmp(fnames_[i].c_str(), "file://", 7)) {
|
if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
|
||||||
std::string tmp = fnames_[i].c_str() + 7;
|
std::string tmp = fnames_[i].c_str() + 7;
|
||||||
fnames_[i] = tmp;
|
fnames_[i] = tmp;
|
||||||
}
|
}
|
||||||
@ -88,11 +89,11 @@ class FileSplit : public LineSplitBase {
|
|||||||
}
|
}
|
||||||
// get file size
|
// get file size
|
||||||
inline static size_t GetFileSize(const char *fname) {
|
inline static size_t GetFileSize(const char *fname) {
|
||||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
std::FILE *fp = utils::FopenCheck(fname, "rb");
|
||||||
// NOTE: fseek may not be good, but serves as ok solution
|
// NOTE: fseek may not be good, but serves as ok solution
|
||||||
fseek(fp, 0, SEEK_END);
|
std::fseek(fp, 0, SEEK_END);
|
||||||
size_t fsize = static_cast<size_t>(ftell(fp));
|
size_t fsize = static_cast<size_t>(std::ftell(fp));
|
||||||
fclose(fp);
|
std::fclose(fp);
|
||||||
return fsize;
|
return fsize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@ namespace io {
|
|||||||
inline InputSplit *CreateInputSplit(const char *uri,
|
inline InputSplit *CreateInputSplit(const char *uri,
|
||||||
unsigned part,
|
unsigned part,
|
||||||
unsigned nsplit) {
|
unsigned nsplit) {
|
||||||
|
using namespace std;
|
||||||
if (!strcmp(uri, "stdin")) {
|
if (!strcmp(uri, "stdin")) {
|
||||||
return new SingleFileSplit(uri);
|
return new SingleFileSplit(uri);
|
||||||
}
|
}
|
||||||
@ -48,6 +49,7 @@ inline InputSplit *CreateInputSplit(const char *uri,
|
|||||||
* \param mode can be 'w' or 'r' for read or write
|
* \param mode can be 'w' or 'r' for read or write
|
||||||
*/
|
*/
|
||||||
inline IStream *CreateStream(const char *uri, const char *mode) {
|
inline IStream *CreateStream(const char *uri, const char *mode) {
|
||||||
|
using namespace std;
|
||||||
if (!strncmp(uri, "file://", 7)) {
|
if (!strncmp(uri, "file://", 7)) {
|
||||||
return new FileStream(uri + 7, mode);
|
return new FileStream(uri + 7, mode);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
#ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
#ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||||
#define RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
#define RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||||
/*!
|
/*!
|
||||||
* \file line_split-inl.h
|
* \std::FILE line_split-inl.h
|
||||||
* \brief base implementation of line-spliter
|
* \brief base implementation of line-spliter
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
@ -30,7 +30,7 @@ class LineSplitBase : public InputSplit {
|
|||||||
if (out_data->length() != 0) return true;
|
if (out_data->length() != 0) return true;
|
||||||
file_ptr_ += 1;
|
file_ptr_ += 1;
|
||||||
if (offset_curr_ != file_offset_[file_ptr_]) {
|
if (offset_curr_ != file_offset_[file_ptr_]) {
|
||||||
utils::Error("warning:file size not calculated correctly\n");
|
utils::Error("warning:std::FILE size not calculated correctly\n");
|
||||||
offset_curr_ = file_offset_[file_ptr_];
|
offset_curr_ = file_offset_[file_ptr_];
|
||||||
}
|
}
|
||||||
if (offset_curr_ >= offset_end_) return false;
|
if (offset_curr_ >= offset_end_) return false;
|
||||||
@ -59,7 +59,7 @@ class LineSplitBase : public InputSplit {
|
|||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize the line spliter,
|
* \brief initialize the line spliter,
|
||||||
* \param file_size, size of each files
|
* \param file_size, size of each std::FILEs
|
||||||
* \param rank the current rank of the data
|
* \param rank the current rank of the data
|
||||||
* \param nsplit number of split we will divide the data into
|
* \param nsplit number of split we will divide the data into
|
||||||
*/
|
*/
|
||||||
@ -96,31 +96,31 @@ class LineSplitBase : public InputSplit {
|
|||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief get the seek stream of given file_index
|
* \brief get the seek stream of given file_index
|
||||||
* \return the corresponding seek stream at head of file
|
* \return the corresponding seek stream at head of std::FILE
|
||||||
*/
|
*/
|
||||||
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
|
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief split names given
|
* \brief split names given
|
||||||
* \param out_fname output file names
|
* \param out_fname output std::FILE names
|
||||||
* \param uri_ the iput uri file
|
* \param uri_ the iput uri std::FILE
|
||||||
* \param dlm deliminetr
|
* \param dlm deliminetr
|
||||||
*/
|
*/
|
||||||
inline static void SplitNames(std::vector<std::string> *out_fname,
|
inline static void SplitNames(std::vector<std::string> *out_fname,
|
||||||
const char *uri_,
|
const char *uri_,
|
||||||
const char *dlm) {
|
const char *dlm) {
|
||||||
std::string uri = uri_;
|
std::string uri = uri_;
|
||||||
char *p = strtok(BeginPtr(uri), dlm);
|
char *p = std::strtok(BeginPtr(uri), dlm);
|
||||||
while (p != NULL) {
|
while (p != NULL) {
|
||||||
out_fname->push_back(std::string(p));
|
out_fname->push_back(std::string(p));
|
||||||
p = strtok(NULL, dlm);
|
p = std::strtok(NULL, dlm);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
/*! \brief current input stream */
|
/*! \brief current input stream */
|
||||||
utils::ISeekStream *fs_;
|
utils::ISeekStream *fs_;
|
||||||
/*! \brief file pointer of which file to read on */
|
/*! \brief std::FILE pointer of which std::FILE to read on */
|
||||||
size_t file_ptr_;
|
size_t file_ptr_;
|
||||||
/*! \brief file pointer where the end of file lies */
|
/*! \brief std::FILE pointer where the end of std::FILE lies */
|
||||||
size_t file_ptr_end_;
|
size_t file_ptr_end_;
|
||||||
/*! \brief get the current offset */
|
/*! \brief get the current offset */
|
||||||
size_t offset_curr_;
|
size_t offset_curr_;
|
||||||
@ -128,7 +128,7 @@ class LineSplitBase : public InputSplit {
|
|||||||
size_t offset_begin_;
|
size_t offset_begin_;
|
||||||
/*! \brief end of the offset */
|
/*! \brief end of the offset */
|
||||||
size_t offset_end_;
|
size_t offset_end_;
|
||||||
/*! \brief byte-offset of each file */
|
/*! \brief byte-offset of each std::FILE */
|
||||||
std::vector<size_t> file_offset_;
|
std::vector<size_t> file_offset_;
|
||||||
/*! \brief buffer reader */
|
/*! \brief buffer reader */
|
||||||
StreamBufferReader reader_;
|
StreamBufferReader reader_;
|
||||||
@ -136,11 +136,11 @@ class LineSplitBase : public InputSplit {
|
|||||||
const static size_t kBufferSize = 256;
|
const static size_t kBufferSize = 256;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief line split from single file */
|
/*! \brief line split from single std::FILE */
|
||||||
class SingleFileSplit : public InputSplit {
|
class SingleFileSplit : public InputSplit {
|
||||||
public:
|
public:
|
||||||
explicit SingleFileSplit(const char *fname) {
|
explicit SingleFileSplit(const char *fname) {
|
||||||
if (!strcmp(fname, "stdin")) {
|
if (!std::strcmp(fname, "stdin")) {
|
||||||
#ifndef RABIT_STRICT_CXX98_
|
#ifndef RABIT_STRICT_CXX98_
|
||||||
use_stdin_ = true; fp_ = stdin;
|
use_stdin_ = true; fp_ = stdin;
|
||||||
#endif
|
#endif
|
||||||
@ -151,13 +151,13 @@ class SingleFileSplit : public InputSplit {
|
|||||||
end_of_file_ = false;
|
end_of_file_ = false;
|
||||||
}
|
}
|
||||||
virtual ~SingleFileSplit(void) {
|
virtual ~SingleFileSplit(void) {
|
||||||
if (!use_stdin_) fclose(fp_);
|
if (!use_stdin_) std::fclose(fp_);
|
||||||
}
|
}
|
||||||
virtual bool NextLine(std::string *out_data) {
|
virtual bool NextLine(std::string *out_data) {
|
||||||
if (end_of_file_) return false;
|
if (end_of_file_) return false;
|
||||||
out_data->clear();
|
out_data->clear();
|
||||||
while (true) {
|
while (true) {
|
||||||
char c = fgetc(fp_);
|
char c = std::fgetc(fp_);
|
||||||
if (c == EOF) {
|
if (c == EOF) {
|
||||||
end_of_file_ = true;
|
end_of_file_ = true;
|
||||||
}
|
}
|
||||||
@ -172,7 +172,7 @@ class SingleFileSplit : public InputSplit {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FILE *fp_;
|
std::FILE *fp_;
|
||||||
bool use_stdin_;
|
bool use_stdin_;
|
||||||
bool end_of_file_;
|
bool end_of_file_;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -26,6 +26,9 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
world_size = -1;
|
world_size = -1;
|
||||||
hadoop_mode = 0;
|
hadoop_mode = 0;
|
||||||
version_number = 0;
|
version_number = 0;
|
||||||
|
// 32 K items
|
||||||
|
reduce_ring_mincount = 32 << 10;
|
||||||
|
// tracker URL
|
||||||
task_id = "NULL";
|
task_id = "NULL";
|
||||||
err_link = NULL;
|
err_link = NULL;
|
||||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||||
@ -33,7 +36,8 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
env_vars.push_back("rabit_task_id");
|
env_vars.push_back("rabit_task_id");
|
||||||
env_vars.push_back("rabit_num_trial");
|
env_vars.push_back("rabit_num_trial");
|
||||||
env_vars.push_back("rabit_reduce_buffer");
|
env_vars.push_back("rabit_reduce_buffer");
|
||||||
env_vars.push_back("rabit_tracker_uri");
|
env_vars.push_back("rabit_reduce_ring_mincount");
|
||||||
|
env_vars.push_back("rabit_tracker_uri");
|
||||||
env_vars.push_back("rabit_tracker_port");
|
env_vars.push_back("rabit_tracker_port");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,6 +120,27 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
|||||||
tracker.SendStr(msg);
|
tracker.SendStr(msg);
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
}
|
}
|
||||||
|
// util to parse data with unit suffix
|
||||||
|
inline size_t ParseUnit(const char *name, const char *val) {
|
||||||
|
char unit;
|
||||||
|
uint64_t amount;
|
||||||
|
int n = sscanf(val, "%lu%c", &amount, &unit);
|
||||||
|
if (n == 2) {
|
||||||
|
switch (unit) {
|
||||||
|
case 'B': return amount;
|
||||||
|
case 'K': return amount << 10UL;
|
||||||
|
case 'M': return amount << 20UL;
|
||||||
|
case 'G': return amount << 30UL;
|
||||||
|
default: utils::Error("invalid format for %s", name); return 0;
|
||||||
|
}
|
||||||
|
} else if (n == 1) {
|
||||||
|
return amount;
|
||||||
|
} else {
|
||||||
|
utils::Error("invalid format for %s," \
|
||||||
|
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
@ -127,21 +152,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
||||||
|
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
||||||
|
reduce_ring_mincount = ParseUnit(name, val);
|
||||||
|
}
|
||||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||||
char unit;
|
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
|
||||||
uint64_t amount;
|
|
||||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
|
||||||
switch (unit) {
|
|
||||||
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
|
|
||||||
case 'K': reduce_buffer_size = amount << 7UL; break;
|
|
||||||
case 'M': reduce_buffer_size = amount << 17UL; break;
|
|
||||||
case 'G': reduce_buffer_size = amount << 27UL; break;
|
|
||||||
default: utils::Error("invalid format for reduce buffer");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
utils::Error("invalid format for reduce_buffer,"\
|
|
||||||
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -341,6 +356,28 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
|
if (count > reduce_ring_mincount) {
|
||||||
|
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
|
||||||
|
} else {
|
||||||
|
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||||
|
* this function implements tree-shape reduction
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
AllreduceBase::ReturnType
|
||||||
|
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer) {
|
||||||
RefLinkVector &links = tree_links;
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0 || count == 0) return kSuccess;
|
if (links.size() == 0 || count == 0) return kSuccess;
|
||||||
// total size of message
|
// total size of message
|
||||||
@ -599,5 +636,217 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
|||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||||
|
* the data provided by current node k is [slice_begin, slice_end),
|
||||||
|
* the next node's segment must start with slice_end
|
||||||
|
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||||
|
* use a ring based algorithm
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||||
|
* \param total_size total size of data to be gathered
|
||||||
|
* \param slice_begin beginning of the current slice
|
||||||
|
* \param slice_end end of the current slice
|
||||||
|
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||||
|
*/
|
||||||
|
AllreduceBase::ReturnType
|
||||||
|
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||||
|
size_t slice_begin,
|
||||||
|
size_t slice_end,
|
||||||
|
size_t size_prev_slice) {
|
||||||
|
// read from next link and send to prev one
|
||||||
|
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||||
|
// need to reply on special rank structure
|
||||||
|
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||||
|
rank == (prev.rank + 1) % world_size,
|
||||||
|
"need to assume rank structure");
|
||||||
|
// send recv buffer
|
||||||
|
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||||
|
const size_t stop_read = total_size + slice_begin;
|
||||||
|
const size_t stop_write = total_size + slice_begin - size_prev_slice;
|
||||||
|
size_t write_ptr = slice_begin;
|
||||||
|
size_t read_ptr = slice_end;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
// select helper
|
||||||
|
bool finished = true;
|
||||||
|
utils::SelectHelper selecter;
|
||||||
|
if (read_ptr != stop_read) {
|
||||||
|
selecter.WatchRead(next.sock);
|
||||||
|
finished = false;
|
||||||
|
}
|
||||||
|
if (write_ptr != stop_write) {
|
||||||
|
if (write_ptr < read_ptr) {
|
||||||
|
selecter.WatchWrite(prev.sock);
|
||||||
|
}
|
||||||
|
finished = false;
|
||||||
|
}
|
||||||
|
if (finished) break;
|
||||||
|
selecter.Select();
|
||||||
|
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
|
||||||
|
size_t size = stop_read - read_ptr;
|
||||||
|
size_t start = read_ptr % total_size;
|
||||||
|
if (start + size > total_size) {
|
||||||
|
size = total_size - start;
|
||||||
|
}
|
||||||
|
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
|
||||||
|
if (len != -1) {
|
||||||
|
read_ptr += static_cast<size_t>(len);
|
||||||
|
} else {
|
||||||
|
ReturnType ret = Errno2Return(errno);
|
||||||
|
if (ret != kSuccess) return ReportError(&next, ret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (write_ptr < read_ptr && write_ptr != stop_write) {
|
||||||
|
size_t size = std::min(read_ptr, stop_write) - write_ptr;
|
||||||
|
size_t start = write_ptr % total_size;
|
||||||
|
if (start + size > total_size) {
|
||||||
|
size = total_size - start;
|
||||||
|
}
|
||||||
|
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
|
||||||
|
if (len != -1) {
|
||||||
|
write_ptr += static_cast<size_t>(len);
|
||||||
|
} else {
|
||||||
|
ReturnType ret = Errno2Return(errno);
|
||||||
|
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
|
||||||
|
* and will return the cause of failure
|
||||||
|
*
|
||||||
|
* Ring-based algorithm
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType, TryAllreduce
|
||||||
|
*/
|
||||||
|
AllreduceBase::ReturnType
|
||||||
|
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer) {
|
||||||
|
// read from next link and send to prev one
|
||||||
|
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||||
|
// need to reply on special rank structure
|
||||||
|
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||||
|
rank == (prev.rank + 1) % world_size,
|
||||||
|
"need to assume rank structure");
|
||||||
|
// total size of message
|
||||||
|
const size_t total_size = type_nbytes * count;
|
||||||
|
size_t n = static_cast<size_t>(world_size);
|
||||||
|
size_t step = (count + n - 1) / n;
|
||||||
|
size_t r = static_cast<size_t>(next.rank);
|
||||||
|
size_t write_ptr = std::min(r * step, count) * type_nbytes;
|
||||||
|
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
|
||||||
|
size_t reduce_ptr = read_ptr;
|
||||||
|
// send recv buffer
|
||||||
|
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||||
|
// position to stop reading
|
||||||
|
const size_t stop_read = total_size + write_ptr;
|
||||||
|
// position to stop writing
|
||||||
|
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
|
||||||
|
if (stop_write > stop_read) {
|
||||||
|
stop_write -= total_size;
|
||||||
|
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
|
||||||
|
}
|
||||||
|
// use ring buffer in next position
|
||||||
|
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
|
||||||
|
// set size_read to read pointer for ring buffer to work properly
|
||||||
|
next.size_read = read_ptr;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
// select helper
|
||||||
|
bool finished = true;
|
||||||
|
utils::SelectHelper selecter;
|
||||||
|
if (read_ptr != stop_read) {
|
||||||
|
selecter.WatchRead(next.sock);
|
||||||
|
finished = false;
|
||||||
|
}
|
||||||
|
if (write_ptr != stop_write) {
|
||||||
|
if (write_ptr < reduce_ptr) {
|
||||||
|
selecter.WatchWrite(prev.sock);
|
||||||
|
}
|
||||||
|
finished = false;
|
||||||
|
}
|
||||||
|
if (finished) break;
|
||||||
|
selecter.Select();
|
||||||
|
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
|
||||||
|
ReturnType ret = next.ReadToRingBuffer(reduce_ptr);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
return ReportError(&next, ret);
|
||||||
|
}
|
||||||
|
// sync the rate
|
||||||
|
read_ptr = next.size_read;
|
||||||
|
utils::Assert(read_ptr <= stop_read, "read_ptr boundary check");
|
||||||
|
const size_t buffer_size = next.buffer_size;
|
||||||
|
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
|
||||||
|
while (reduce_ptr < max_reduce) {
|
||||||
|
size_t bstart = reduce_ptr % buffer_size;
|
||||||
|
size_t nread = std::min(buffer_size - bstart,
|
||||||
|
max_reduce - reduce_ptr);
|
||||||
|
size_t rstart = reduce_ptr % total_size;
|
||||||
|
nread = std::min(nread, total_size - rstart);
|
||||||
|
reducer(next.buffer_head + bstart,
|
||||||
|
sendrecvbuf + rstart,
|
||||||
|
static_cast<int>(nread / type_nbytes),
|
||||||
|
MPI::Datatype(type_nbytes));
|
||||||
|
reduce_ptr += nread;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
|
||||||
|
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
|
||||||
|
size_t start = write_ptr % total_size;
|
||||||
|
if (start + size > total_size) {
|
||||||
|
size = total_size - start;
|
||||||
|
}
|
||||||
|
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
|
||||||
|
if (len != -1) {
|
||||||
|
write_ptr += static_cast<size_t>(len);
|
||||||
|
} else {
|
||||||
|
ReturnType ret = Errno2Return(errno);
|
||||||
|
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* use a ring based algorithm
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
AllreduceBase::ReturnType
|
||||||
|
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer) {
|
||||||
|
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
|
||||||
|
if (ret != kSuccess) return ret;
|
||||||
|
size_t n = static_cast<size_t>(world_size);
|
||||||
|
size_t step = (count + n - 1) / n;
|
||||||
|
size_t begin = std::min(rank * step, count) * type_nbytes;
|
||||||
|
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
|
||||||
|
// previous rank
|
||||||
|
int prank = ring_prev->rank;
|
||||||
|
// get rank of previous
|
||||||
|
return TryAllgatherRing
|
||||||
|
(sendrecvbuf_, type_nbytes * count,
|
||||||
|
begin, end,
|
||||||
|
(std::min((prank + 1) * step, count) -
|
||||||
|
std::min(prank * step, count)) * type_nbytes);
|
||||||
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -380,13 +380,79 @@ class AllreduceBase : public IEngine {
|
|||||||
ReduceFunction reducer);
|
ReduceFunction reducer);
|
||||||
/*!
|
/*!
|
||||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||||
* \param size the size of the data to be broadcasted
|
* \param size the size of the data to be broadcasted
|
||||||
* \param root the root worker id to broadcast the data
|
* \param root the root worker id to broadcast the data
|
||||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||||
|
* this function implements tree-shape reduction
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
ReturnType TryAllreduceTree(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer);
|
||||||
|
/*!
|
||||||
|
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||||
|
* the data provided by current node k is [slice_begin, slice_end),
|
||||||
|
* the next node's segment must start with slice_end
|
||||||
|
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||||
|
* use a ring based algorithm
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||||
|
* \param total_size total size of data to be gathered
|
||||||
|
* \param slice_begin beginning of the current slice
|
||||||
|
* \param slice_end end of the current slice
|
||||||
|
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||||
|
size_t slice_begin, size_t slice_end,
|
||||||
|
size_t size_prev_slice);
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
|
||||||
|
*
|
||||||
|
* after the function, node k get k-th segment of the reduction result
|
||||||
|
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
|
||||||
|
* where step = ceil(count / world_size)
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType, TryAllreduce
|
||||||
|
*/
|
||||||
|
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer);
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* use a ring based algorithm, reduce-scatter + allgather
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
ReturnType TryAllreduceRing(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer);
|
||||||
/*!
|
/*!
|
||||||
* \brief function used to report error when a link goes wrong
|
* \brief function used to report error when a link goes wrong
|
||||||
* \param link the pointer to the link who causes the error
|
* \param link the pointer to the link who causes the error
|
||||||
@ -432,6 +498,10 @@ class AllreduceBase : public IEngine {
|
|||||||
int slave_port, nport_trial;
|
int slave_port, nport_trial;
|
||||||
// reduce buffer size
|
// reduce buffer size
|
||||||
size_t reduce_buffer_size;
|
size_t reduce_buffer_size;
|
||||||
|
// reduction method
|
||||||
|
int reduce_method;
|
||||||
|
// mininum count of cells to use ring based method
|
||||||
|
size_t reduce_ring_mincount;
|
||||||
// current rank
|
// current rank
|
||||||
int rank;
|
int rank;
|
||||||
// world size
|
// world size
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
export CC = gcc
|
export CC = gcc
|
||||||
export CXX = g++
|
export CXX = g++
|
||||||
export MPICXX = mpicxx
|
export MPICXX = mpicxx
|
||||||
export LDFLAGS= -pthread -lm -lrt -L../lib
|
export LDFLAGS= -L../lib -pthread -lm -lrt
|
||||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11
|
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
@ -29,7 +29,7 @@ local_recover: local_recover.o $(RABIT_OBJ)
|
|||||||
lazy_recover: lazy_recover.o $(RABIT_OBJ)
|
lazy_recover: lazy_recover.o $(RABIT_OBJ)
|
||||||
|
|
||||||
$(BIN) :
|
$(BIN) :
|
||||||
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
|
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS)
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|||||||
@ -23,4 +23,7 @@ lazy_recover_10_10k_die_hard:
|
|||||||
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
||||||
|
|
||||||
lazy_recover_10_10k_die_same:
|
lazy_recover_10_10k_die_same:
|
||||||
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
||||||
|
|
||||||
|
ringallreduce_10_10k:
|
||||||
|
../tracker/rabit_demo.py -v 1 -n 10 model_recover 100 rabit_reduce_ring_mincount=10
|
||||||
|
|||||||
@ -188,6 +188,7 @@ class Tracker:
|
|||||||
vlst.reverse()
|
vlst.reverse()
|
||||||
rlst += vlst
|
rlst += vlst
|
||||||
return rlst
|
return rlst
|
||||||
|
|
||||||
def get_ring(self, tree_map, parent_map):
|
def get_ring(self, tree_map, parent_map):
|
||||||
"""
|
"""
|
||||||
get a ring connection used to recover local data
|
get a ring connection used to recover local data
|
||||||
@ -202,14 +203,44 @@ class Tracker:
|
|||||||
rnext = (r + 1) % nslave
|
rnext = (r + 1) % nslave
|
||||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||||
return ring_map
|
return ring_map
|
||||||
|
|
||||||
|
def get_link_map(self, nslave):
|
||||||
|
"""
|
||||||
|
get the link map, this is a bit hacky, call for better algorithm
|
||||||
|
to place similar nodes together
|
||||||
|
"""
|
||||||
|
tree_map, parent_map = self.get_tree(nslave)
|
||||||
|
ring_map = self.get_ring(tree_map, parent_map)
|
||||||
|
rmap = {0 : 0}
|
||||||
|
k = 0
|
||||||
|
for i in range(nslave - 1):
|
||||||
|
k = ring_map[k][1]
|
||||||
|
rmap[k] = i + 1
|
||||||
|
|
||||||
|
ring_map_ = {}
|
||||||
|
tree_map_ = {}
|
||||||
|
parent_map_ ={}
|
||||||
|
for k, v in ring_map.items():
|
||||||
|
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
||||||
|
for k, v in tree_map.items():
|
||||||
|
tree_map_[rmap[k]] = [rmap[x] for x in v]
|
||||||
|
for k, v in parent_map.items():
|
||||||
|
if k != 0:
|
||||||
|
parent_map_[rmap[k]] = rmap[v]
|
||||||
|
else:
|
||||||
|
parent_map_[rmap[k]] = -1
|
||||||
|
return tree_map_, parent_map_, ring_map_
|
||||||
|
|
||||||
def handle_print(self,slave, msg):
|
def handle_print(self,slave, msg):
|
||||||
sys.stdout.write(msg)
|
sys.stdout.write(msg)
|
||||||
|
|
||||||
def log_print(self, msg, level):
|
def log_print(self, msg, level):
|
||||||
if level == 1:
|
if level == 1:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
sys.stderr.write(msg + '\n')
|
sys.stderr.write(msg + '\n')
|
||||||
else:
|
else:
|
||||||
sys.stderr.write(msg + '\n')
|
sys.stderr.write(msg + '\n')
|
||||||
|
|
||||||
def accept_slaves(self, nslave):
|
def accept_slaves(self, nslave):
|
||||||
# set of nodes that finishs the job
|
# set of nodes that finishs the job
|
||||||
shutdown = {}
|
shutdown = {}
|
||||||
@ -241,30 +272,36 @@ class Tracker:
|
|||||||
assert s.cmd == 'start'
|
assert s.cmd == 'start'
|
||||||
if s.world_size > 0:
|
if s.world_size > 0:
|
||||||
nslave = s.world_size
|
nslave = s.world_size
|
||||||
tree_map, parent_map = self.get_tree(nslave)
|
tree_map, parent_map, ring_map = self.get_link_map(nslave)
|
||||||
ring_map = self.get_ring(tree_map, parent_map)
|
|
||||||
# set of nodes that is pending for getting up
|
# set of nodes that is pending for getting up
|
||||||
todo_nodes = range(nslave)
|
todo_nodes = range(nslave)
|
||||||
random.shuffle(todo_nodes)
|
|
||||||
else:
|
else:
|
||||||
assert s.world_size == -1 or s.world_size == nslave
|
assert s.world_size == -1 or s.world_size == nslave
|
||||||
if s.cmd == 'recover':
|
if s.cmd == 'recover':
|
||||||
assert s.rank >= 0
|
assert s.rank >= 0
|
||||||
|
|
||||||
rank = s.decide_rank(job_map)
|
rank = s.decide_rank(job_map)
|
||||||
|
# batch assignment of ranks
|
||||||
if rank == -1:
|
if rank == -1:
|
||||||
assert len(todo_nodes) != 0
|
assert len(todo_nodes) != 0
|
||||||
rank = todo_nodes.pop(0)
|
pending.append(s)
|
||||||
if s.jobid != 'NULL':
|
if len(pending) == len(todo_nodes):
|
||||||
job_map[s.jobid] = rank
|
pending.sort(key = lambda x : x.host)
|
||||||
|
for s in pending:
|
||||||
|
rank = todo_nodes.pop(0)
|
||||||
|
if s.jobid != 'NULL':
|
||||||
|
job_map[s.jobid] = rank
|
||||||
|
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||||
|
if s.wait_accept > 0:
|
||||||
|
wait_conn[rank] = s
|
||||||
|
self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
|
||||||
if len(todo_nodes) == 0:
|
if len(todo_nodes) == 0:
|
||||||
self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
|
self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
|
||||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
|
||||||
if s.cmd != 'start':
|
|
||||||
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
|
|
||||||
else:
|
else:
|
||||||
self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
|
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||||
if s.wait_accept > 0:
|
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1)
|
||||||
wait_conn[rank] = s
|
if s.wait_accept > 0:
|
||||||
|
wait_conn[rank] = s
|
||||||
self.log_print('@tracker All nodes finishes job', 2)
|
self.log_print('@tracker All nodes finishes job', 2)
|
||||||
|
|
||||||
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
|
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user