From c57dad8b1719a88bf300cc6a216f6b0d359f4f6d Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 11 Mar 2015 12:00:19 -0700 Subject: [PATCH] add ringbased passing and batch schedule --- rabit-learn/io/file-inl.h | 15 +- rabit-learn/io/io-inl.h | 2 + rabit-learn/io/line_split-inl.h | 32 ++-- src/allreduce_base.cc | 279 ++++++++++++++++++++++++++++++-- src/allreduce_base.h | 74 ++++++++- test/Makefile | 4 +- test/test.mk | 5 +- tracker/rabit_tracker.py | 63 ++++++-- 8 files changed, 418 insertions(+), 56 deletions(-) diff --git a/rabit-learn/io/file-inl.h b/rabit-learn/io/file-inl.h index d77a943de..8e9d9c593 100644 --- a/rabit-learn/io/file-inl.h +++ b/rabit-learn/io/file-inl.h @@ -19,6 +19,7 @@ class FileStream : public utils::ISeekStream { public: explicit FileStream(const char *fname, const char *mode) : use_stdio(false) { + using namespace std; #ifndef RABIT_STRICT_CXX98_ if (!strcmp(fname, "stdin")) { use_stdio = true; fp = stdin; @@ -51,7 +52,7 @@ class FileStream : public utils::ISeekStream { return std::ftell(fp); } virtual bool AtEnd(void) const { - return feof(fp) != 0; + return std::feof(fp) != 0; } inline void Close(void) { if (fp != NULL && !use_stdio) { @@ -60,7 +61,7 @@ class FileStream : public utils::ISeekStream { } private: - FILE *fp; + std::FILE *fp; bool use_stdio; }; @@ -71,7 +72,7 @@ class FileSplit : public LineSplitBase { LineSplitBase::SplitNames(&fnames_, uri, "#"); std::vector fsize; for (size_t i = 0; i < fnames_.size(); ++i) { - if (!strncmp(fnames_[i].c_str(), "file://", 7)) { + if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) { std::string tmp = fnames_[i].c_str() + 7; fnames_[i] = tmp; } @@ -88,11 +89,11 @@ class FileSplit : public LineSplitBase { } // get file size inline static size_t GetFileSize(const char *fname) { - FILE *fp = utils::FopenCheck(fname, "rb"); + std::FILE *fp = utils::FopenCheck(fname, "rb"); // NOTE: fseek may not be good, but serves as ok solution - fseek(fp, 0, SEEK_END); - size_t fsize = 
static_cast(ftell(fp)); - fclose(fp); + std::fseek(fp, 0, SEEK_END); + size_t fsize = static_cast(std::ftell(fp)); + std::fclose(fp); return fsize; } diff --git a/rabit-learn/io/io-inl.h b/rabit-learn/io/io-inl.h index b09e9da7c..db599cedc 100644 --- a/rabit-learn/io/io-inl.h +++ b/rabit-learn/io/io-inl.h @@ -25,6 +25,7 @@ namespace io { inline InputSplit *CreateInputSplit(const char *uri, unsigned part, unsigned nsplit) { + using namespace std; if (!strcmp(uri, "stdin")) { return new SingleFileSplit(uri); } @@ -48,6 +49,7 @@ inline InputSplit *CreateInputSplit(const char *uri, * \param mode can be 'w' or 'r' for read or write */ inline IStream *CreateStream(const char *uri, const char *mode) { + using namespace std; if (!strncmp(uri, "file://", 7)) { return new FileStream(uri + 7, mode); } diff --git a/rabit-learn/io/line_split-inl.h b/rabit-learn/io/line_split-inl.h index 7ef322137..ba21d39e4 100644 --- a/rabit-learn/io/line_split-inl.h +++ b/rabit-learn/io/line_split-inl.h @@ -1,7 +1,7 @@ #ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_ #define RABIT_LEARN_IO_LINE_SPLIT_INL_H_ /*! - * \file line_split-inl.h + * \file line_split-inl.h * \brief base implementation of line-spliter * \author Tianqi Chen */ @@ -30,7 +30,7 @@ class LineSplitBase : public InputSplit { if (out_data->length() != 0) return true; file_ptr_ += 1; if (offset_curr_ != file_offset_[file_ptr_]) { - utils::Error("warning:file size not calculated correctly\n"); + utils::Error("warning:std::FILE size not calculated correctly\n"); offset_curr_ = file_offset_[file_ptr_]; } if (offset_curr_ >= offset_end_) return false; @@ -59,7 +59,7 @@ class LineSplitBase : public InputSplit { } /*! * \brief initialize the line spliter, - * \param file_size, size of each files + * \param file_size, size of each files * \param rank the current rank of the data * \param nsplit number of split we will divide the data into */ @@ -96,31 +96,31 @@ class LineSplitBase : public InputSplit { } /*!
* \brief get the seek stream of given file_index - * \return the corresponding seek stream at head of file + * \return the corresponding seek stream at head of file */ virtual utils::ISeekStream *GetFile(size_t file_index) = 0; /*! * \brief split names given - * \param out_fname output file names - * \param uri_ the iput uri file + * \param out_fname output file names + * \param uri_ the input uri file * \param dlm deliminetr */ inline static void SplitNames(std::vector *out_fname, const char *uri_, const char *dlm) { std::string uri = uri_; - char *p = strtok(BeginPtr(uri), dlm); + char *p = std::strtok(BeginPtr(uri), dlm); while (p != NULL) { out_fname->push_back(std::string(p)); - p = strtok(NULL, dlm); + p = std::strtok(NULL, dlm); } } private: /*! \brief current input stream */ utils::ISeekStream *fs_; - /*! \brief file pointer of which file to read on */ + /*! \brief file pointer of which file to read on */ size_t file_ptr_; - /*! \brief file pointer where the end of file lies */ + /*! \brief file pointer where the end of file lies */ size_t file_ptr_end_; /*! \brief get the current offset */ size_t offset_curr_; @@ -128,7 +128,7 @@ class LineSplitBase : public InputSplit { size_t offset_begin_; /*! \brief end of the offset */ size_t offset_end_; - /*! \brief byte-offset of each file */ + /*! \brief byte-offset of each file */ std::vector file_offset_; /*! \brief buffer reader */ StreamBufferReader reader_; @@ -136,11 +136,11 @@ class LineSplitBase : public InputSplit { const static size_t kBufferSize = 256; }; -/*! \brief line split from single file */ +/*!
\brief line split from single file */ class SingleFileSplit : public InputSplit { public: explicit SingleFileSplit(const char *fname) { - if (!strcmp(fname, "stdin")) { + if (!std::strcmp(fname, "stdin")) { #ifndef RABIT_STRICT_CXX98_ use_stdin_ = true; fp_ = stdin; #endif @@ -151,13 +151,13 @@ class SingleFileSplit : public InputSplit { end_of_file_ = false; } virtual ~SingleFileSplit(void) { - if (!use_stdin_) fclose(fp_); + if (!use_stdin_) std::fclose(fp_); } virtual bool NextLine(std::string *out_data) { if (end_of_file_) return false; out_data->clear(); while (true) { - char c = fgetc(fp_); + char c = std::fgetc(fp_); if (c == EOF) { end_of_file_ = true; } @@ -172,7 +172,7 @@ class SingleFileSplit : public InputSplit { } private: - FILE *fp_; + std::FILE *fp_; bool use_stdin_; bool end_of_file_; }; diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 831722dc5..841525bc2 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -26,6 +26,9 @@ AllreduceBase::AllreduceBase(void) { world_size = -1; hadoop_mode = 0; version_number = 0; + // 32 K items + reduce_ring_mincount = 32 << 10; + // tracker URL task_id = "NULL"; err_link = NULL; this->SetParam("rabit_reduce_buffer", "256MB"); @@ -33,7 +36,8 @@ AllreduceBase::AllreduceBase(void) { env_vars.push_back("rabit_task_id"); env_vars.push_back("rabit_num_trial"); env_vars.push_back("rabit_reduce_buffer"); - env_vars.push_back("rabit_tracker_uri"); + env_vars.push_back("rabit_reduce_ring_mincount"); + env_vars.push_back("rabit_tracker_uri"); env_vars.push_back("rabit_tracker_port"); } @@ -116,6 +120,27 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { tracker.SendStr(msg); tracker.Close(); } +// util to parse data with unit suffix +inline size_t ParseUnit(const char *name, const char *val) { + char unit; + uint64_t amount; + int n = sscanf(val, "%lu%c", &amount, &unit); + if (n == 2) { + switch (unit) { + case 'B': return amount; + case 'K': return amount << 10UL; + case 'M':
return amount << 20UL; + case 'G': return amount << 30UL; + default: utils::Error("invalid format for %s", name); return 0; + } + } else if (n == 1) { + return amount; + } else { + utils::Error("invalid format for %s," \ + "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name); + return 0; + } +} /*! * \brief set parameters to the engine * \param name parameter name @@ -127,21 +152,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "rabit_task_id")) task_id = val; if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); + if (!strcmp(name, "rabit_reduce_ring_mincount")) { + reduce_ring_mincount = ParseUnit(name, val); + } if (!strcmp(name, "rabit_reduce_buffer")) { - char unit; - uint64_t amount; - if (sscanf(val, "%lu%c", &amount, &unit) == 2) { - switch (unit) { - case 'B': reduce_buffer_size = (amount + 7)/ 8; break; - case 'K': reduce_buffer_size = amount << 7UL; break; - case 'M': reduce_buffer_size = amount << 17UL; break; - case 'G': reduce_buffer_size = amount << 27UL; break; - default: utils::Error("invalid format for reduce buffer"); - } - } else { - utils::Error("invalid format for reduce_buffer,"\ - "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); - } + reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3; } } /*! @@ -341,6 +356,28 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer) { + if (count > reduce_ring_mincount) { + return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer); + } else { + return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer); + } +} +/*! 
+ * \brief perform in-place allreduce, on sendrecvbuf, + * this function implements tree-shape reduction + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType + */ +AllreduceBase::ReturnType +AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer) { RefLinkVector &links = tree_links; if (links.size() == 0 || count == 0) return kSuccess; // total size of message @@ -599,5 +636,217 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } return kSuccess; } +/*! + * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments + * use a ring based algorithm + * + * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered + * \param slice_begin beginning of the current slice + * \param slice_end end of the current slice + * \param size_prev_slice size of the previous slice i.e. 
slice of node (rank - 1) % world_size + */ +AllreduceBase::ReturnType +AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice) { + // read from next link and send to prev one + LinkRecord &prev = *ring_prev, &next = *ring_next; + // need to reply on special rank structure + utils::Assert(next.rank == (rank + 1) % world_size && + rank == (prev.rank + 1) % world_size, + "need to assume rank structure"); + // send recv buffer + char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); + const size_t stop_read = total_size + slice_begin; + const size_t stop_write = total_size + slice_begin - size_prev_slice; + size_t write_ptr = slice_begin; + size_t read_ptr = slice_end; + + while (true) { + // select helper + bool finished = true; + utils::SelectHelper selecter; + if (read_ptr != stop_read) { + selecter.WatchRead(next.sock); + finished = false; + } + if (write_ptr != stop_write) { + if (write_ptr < read_ptr) { + selecter.WatchWrite(prev.sock); + } + finished = false; + } + if (finished) break; + selecter.Select(); + if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { + size_t size = stop_read - read_ptr; + size_t start = read_ptr % total_size; + if (start + size > total_size) { + size = total_size - start; + } + ssize_t len = next.sock.Recv(sendrecvbuf + start, size); + if (len != -1) { + read_ptr += static_cast(len); + } else { + ReturnType ret = Errno2Return(errno); + if (ret != kSuccess) return ReportError(&next, ret); + } + } + if (write_ptr < read_ptr && write_ptr != stop_write) { + size_t size = std::min(read_ptr, stop_write) - write_ptr; + size_t start = write_ptr % total_size; + if (start + size > total_size) { + size = total_size - start; + } + ssize_t len = prev.sock.Send(sendrecvbuf + start, size); + if (len != -1) { + write_ptr += static_cast(len); + } else { + ReturnType ret = Errno2Return(errno); + if (ret != kSuccess) return ReportError(&prev, ret); + } + } + } + 
return kSuccess; +} +/*! + * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, + * and will return the cause of failure + * + * Ring-based algorithm + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType, TryAllreduce + */ +AllreduceBase::ReturnType +AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer) { + // read from next link and send to prev one + LinkRecord &prev = *ring_prev, &next = *ring_next; + // need to reply on special rank structure + utils::Assert(next.rank == (rank + 1) % world_size && + rank == (prev.rank + 1) % world_size, + "need to assume rank structure"); + // total size of message + const size_t total_size = type_nbytes * count; + size_t n = static_cast(world_size); + size_t step = (count + n - 1) / n; + size_t r = static_cast(next.rank); + size_t write_ptr = std::min(r * step, count) * type_nbytes; + size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes; + size_t reduce_ptr = read_ptr; + // send recv buffer + char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); + // position to stop reading + const size_t stop_read = total_size + write_ptr; + // position to stop writing + size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes; + if (stop_write > stop_read) { + stop_write -= total_size; + utils::Assert(write_ptr <= stop_write, "write ptr boundary check"); + } + // use ring buffer in next position + next.InitBuffer(type_nbytes, step, reduce_buffer_size); + // set size_read to read pointer for ring buffer to work properly + next.size_read = read_ptr; + + while (true) { + // select helper + bool finished = true; + utils::SelectHelper selecter; + if 
(read_ptr != stop_read) { + selecter.WatchRead(next.sock); + finished = false; + } + if (write_ptr != stop_write) { + if (write_ptr < reduce_ptr) { + selecter.WatchWrite(prev.sock); + } + finished = false; + } + if (finished) break; + selecter.Select(); + if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { + ReturnType ret = next.ReadToRingBuffer(reduce_ptr); + if (ret != kSuccess) { + return ReportError(&next, ret); + } + // sync the rate + read_ptr = next.size_read; + utils::Assert(read_ptr <= stop_read, "read_ptr boundary check"); + const size_t buffer_size = next.buffer_size; + size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes; + while (reduce_ptr < max_reduce) { + size_t bstart = reduce_ptr % buffer_size; + size_t nread = std::min(buffer_size - bstart, + max_reduce - reduce_ptr); + size_t rstart = reduce_ptr % total_size; + nread = std::min(nread, total_size - rstart); + reducer(next.buffer_head + bstart, + sendrecvbuf + rstart, + static_cast(nread / type_nbytes), + MPI::Datatype(type_nbytes)); + reduce_ptr += nread; + } + } + if (write_ptr < reduce_ptr && write_ptr != stop_write) { + size_t size = std::min(reduce_ptr, stop_write) - write_ptr; + size_t start = write_ptr % total_size; + if (start + size > total_size) { + size = total_size - start; + } + ssize_t len = prev.sock.Send(sendrecvbuf + start, size); + if (len != -1) { + write_ptr += static_cast(len); + } else { + ReturnType ret = Errno2Return(errno); + if (ret != kSuccess) return ReportError(&prev, ret); + } + } + } + return kSuccess; +} +/*! 
+ * \brief perform in-place allreduce, on sendrecvbuf + * use a ring based algorithm + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType + */ +AllreduceBase::ReturnType +AllreduceBase::TryAllreduceRing(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer) { + ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer); + if (ret != kSuccess) return ret; + size_t n = static_cast(world_size); + size_t step = (count + n - 1) / n; + size_t begin = std::min(rank * step, count) * type_nbytes; + size_t end = std::min((rank + 1) * step, count) * type_nbytes; + // previous rank + int prank = ring_prev->rank; + // get rank of previous + return TryAllgatherRing + (sendrecvbuf_, type_nbytes * count, + begin, end, + (std::min((prank + 1) * step, count) - + std::min(prank * step, count)) * type_nbytes); +} } // namespace engine } // namespace rabit diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 2d75da684..af4c7cfdc 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -380,13 +380,79 @@ class AllreduceBase : public IEngine { ReduceFunction reducer); /*! 
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure - * \param sendrecvbuf_ buffer for both sending and recving data + * \param sendrecvbuf_ buffer for both sending and receiving data * \param size the size of the data to be broadcasted * \param root the root worker id to broadcast the data * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details * \sa ReturnType */ - ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); + ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); + /*! + * \brief perform in-place allreduce, on sendrecvbuf, + * this function implements tree-shape reduction + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType + */ + ReturnType TryAllreduceTree(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer); + /*! + * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments + * use a ring based algorithm + * + * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered + * \param slice_begin beginning of the current slice + * \param slice_end end of the current slice + * \param size_prev_slice size of the previous slice i.e. 
slice of node (rank - 1) % world_size + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType + */ + ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size, + size_t slice_begin, size_t slice_end, + size_t size_prev_slice); + /*! + * \brief perform in-place allreduce, reduce on the sendrecvbuf, + * + * after the function, node k get k-th segment of the reduction result + * the k-th segment is defined by [k * step, min((k + 1) * step,count) ) + * where step = ceil(count / world_size) + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType, TryAllreduce + */ + ReturnType TryReduceScatterRing(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer); + /*! + * \brief perform in-place allreduce, on sendrecvbuf + * use a ring based algorithm, reduce-scatter + allgather + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details + * \sa ReturnType + */ + ReturnType TryAllreduceRing(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer); /*! 
* \brief function used to report error when a link goes wrong * \param link the pointer to the link who causes the error @@ -432,6 +498,10 @@ class AllreduceBase : public IEngine { int slave_port, nport_trial; // reduce buffer size size_t reduce_buffer_size; + // reduction method + int reduce_method; + // minimum count of cells to use ring based method + size_t reduce_ring_mincount; // current rank int rank; // world size diff --git a/test/Makefile b/test/Makefile index 5ff983e81..a1ff6a854 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,7 +1,7 @@ export CC = gcc export CXX = g++ export MPICXX = mpicxx -export LDFLAGS= -pthread -lm -lrt -L../lib +export LDFLAGS= -L../lib -pthread -lm -lrt export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=c++11 # specify tensor path @@ -29,7 +29,7 @@ local_recover: local_recover.o $(RABIT_OBJ) lazy_recover: lazy_recover.o $(RABIT_OBJ) $(BIN) : - $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock + $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) -lrabit_mock $(LDFLAGS) $(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) diff --git a/test/test.mk b/test/test.mk index 360bc6cfe..be3429bab 100644 --- a/test/test.mk +++ b/test/test.mk @@ -23,4 +23,7 @@ lazy_recover_10_10k_die_hard: ../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 lazy_recover_10_10k_die_same: - ../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 \ No newline at end of file + ../tracker/rabit_demo.py -n 10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 + +ringallreduce_10_10k: + ../tracker/rabit_demo.py -v 1 -n 10 model_recover 100 rabit_reduce_ring_mincount=10 diff --git a/tracker/rabit_tracker.py b/tracker/rabit_tracker.py index 61ad32497..62ecd92d5 100644 ---
a/tracker/rabit_tracker.py +++ b/tracker/rabit_tracker.py @@ -188,6 +188,7 @@ class Tracker: vlst.reverse() rlst += vlst return rlst + def get_ring(self, tree_map, parent_map): """ get a ring connection used to recover local data @@ -202,14 +203,44 @@ class Tracker: rnext = (r + 1) % nslave ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) return ring_map + + def get_link_map(self, nslave): + """ + get the link map, this is a bit hacky, call for better algorithm + to place similar nodes together + """ + tree_map, parent_map = self.get_tree(nslave) + ring_map = self.get_ring(tree_map, parent_map) + rmap = {0 : 0} + k = 0 + for i in range(nslave - 1): + k = ring_map[k][1] + rmap[k] = i + 1 + + ring_map_ = {} + tree_map_ = {} + parent_map_ ={} + for k, v in ring_map.items(): + ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) + for k, v in tree_map.items(): + tree_map_[rmap[k]] = [rmap[x] for x in v] + for k, v in parent_map.items(): + if k != 0: + parent_map_[rmap[k]] = rmap[v] + else: + parent_map_[rmap[k]] = -1 + return tree_map_, parent_map_, ring_map_ + def handle_print(self,slave, msg): sys.stdout.write(msg) + def log_print(self, msg, level): if level == 1: if self.verbose: sys.stderr.write(msg + '\n') else: sys.stderr.write(msg + '\n') + def accept_slaves(self, nslave): # set of nodes that finishs the job shutdown = {} @@ -241,30 +272,36 @@ class Tracker: assert s.cmd == 'start' if s.world_size > 0: nslave = s.world_size - tree_map, parent_map = self.get_tree(nslave) - ring_map = self.get_ring(tree_map, parent_map) + tree_map, parent_map, ring_map = self.get_link_map(nslave) # set of nodes that is pending for getting up todo_nodes = range(nslave) - random.shuffle(todo_nodes) else: assert s.world_size == -1 or s.world_size == nslave if s.cmd == 'recover': assert s.rank >= 0 + rank = s.decide_rank(job_map) + # batch assignment of ranks if rank == -1: assert len(todo_nodes) != 0 - rank = todo_nodes.pop(0) - if s.jobid != 'NULL': - job_map[s.jobid] = rank + 
pending.append(s) + if len(pending) == len(todo_nodes): + pending.sort(key = lambda x : x.host) + for s in pending: + rank = todo_nodes.pop(0) + if s.jobid != 'NULL': + job_map[s.jobid] = rank + s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) + if s.wait_accept > 0: + wait_conn[rank] = s + self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1) if len(todo_nodes) == 0: - self.log_print('@tracker All of %d nodes getting started' % nslave, 2) - s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) - if s.cmd != 'start': - self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1) + self.log_print('@tracker All of %d nodes getting started' % nslave, 2) else: - self.log_print('Recieve %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1) - if s.wait_accept > 0: - wait_conn[rank] = s + s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) + self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank), 1) + if s.wait_accept > 0: + wait_conn[rank] = s self.log_print('@tracker All nodes finishes job', 2) def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):