diff --git a/include/rabit/c_api.h b/include/rabit/c_api.h index 87c3d40a4..0a96ef7d3 100644 --- a/include/rabit/c_api.h +++ b/include/rabit/c_api.h @@ -43,6 +43,12 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]); */ RABIT_DLL bool RabitFinalize(void); +/*! + * \brief get rank of previous process in ring topology + * \return rank number of the previous worker in the ring + * */ +RABIT_DLL int RabitGetRingPrevRank(void); + /*! * \brief get rank of current process * \return rank number of worker @@ -87,6 +93,30 @@ RABIT_DLL void RabitGetProcessorName(char *out_name, */ RABIT_DLL void RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root); + +/*! + * \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [beginIndex, beginIndex + size_node_slice), + * the next node's segment must start where the current one ends + * after the call of Allgather, sendrecvbuf contains all the contents including all segments + * uses a ring based algorithm + * + * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered, counted in elements of type enum_dtype + * \param beginIndex beginning of the current slice in sendrecvbuf, counted in elements of type enum_dtype + * \param size_node_slice size of the current node slice, counted in elements of type enum_dtype + * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size + * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include + * \note this function returns no value; no ReturnType status code is reported to the caller + * \sa rabit::engine::mpi::DataType + */ +RABIT_DLL void RabitAllgather(void *sendrecvbuf, + size_t total_size, + size_t beginIndex, + size_t size_node_slice, + size_t size_prev_slice, + int enum_dtype); + /*! 
* \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe diff --git a/include/rabit/internal/engine.h b/include/rabit/internal/engine.h index 6e4d306b4..98bfe7fb5 100644 --- a/include/rabit/internal/engine.h +++ b/include/rabit/internal/engine.h @@ -53,6 +53,29 @@ class IEngine { const MPI::Datatype &dtype); /*! \brief virtual destructor */ virtual ~IEngine() {} + /*! + * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments + * use a ring based algorithm + * + * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered + * \param slice_begin beginning of the current slice + * \param slice_end end of the current slice + * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size + * \param _file caller file name used to generate unique cache key + * \param _line caller line number used to generate unique cache key + * \param _caller caller function name used to generate unique cache key + */ + virtual void Allgather(void *sendrecvbuf, size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER) = 0; /*! * \brief performs in-place Allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -165,6 +188,8 @@ class IEngine { * \sa LoadCheckPoint, CheckPoint */ virtual int VersionNumber(void) const = 0; + /*! \brief gets rank of previous node in ring topology */ + virtual int GetRingPrevRank(void) const = 0; /*! \brief gets rank of current node */ virtual int GetRank(void) const = 0; /*! 
\brief gets total number of nodes */ @@ -212,6 +237,26 @@ enum DataType { kULongLong = 9 }; } // namespace mpi +/*! + * \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf contains all the contents including all segments + * uses a ring based algorithm + * + * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered, in bytes + * \param slice_begin beginning of the current slice, as a byte offset into sendrecvbuf + * \param slice_end end of the current slice, as a byte offset into sendrecvbuf + * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size + * \note this function returns no value; no ReturnType status code is reported to the caller + * \sa IEngine::Allgather + */ +void Allgather(void* sendrecvbuf, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice); /*! 
* \brief perform in-place Allreduce, on sendrecvbuf * this is an internal function used by rabit to be able to compile with MPI diff --git a/include/rabit/internal/rabit-inl.h b/include/rabit/internal/rabit-inl.h index 7a3f9de75..f56b37669 100644 --- a/include/rabit/internal/rabit-inl.h +++ b/include/rabit/internal/rabit-inl.h @@ -110,6 +110,10 @@ inline bool Init(int argc, char *argv[]) { inline bool Finalize(void) { return engine::Finalize(); } +// get the rank of the previous worker in ring topology +inline int GetRingPrevRank(void) { + return engine::GetEngine()->GetRingPrevRank(); +} // get the rank of current process inline int GetRank(void) { return engine::GetEngine()->GetRank(); @@ -194,6 +198,16 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, } #endif // C++11 +// Performs inplace Allgather +template +inline void Allgather(DType *sendrecvbuf, size_t totalSize, + size_t beginIndex, size_t sizeNodeSlice, + size_t sizePrevSlice) { + engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType), + (beginIndex + sizeNodeSlice) * sizeof(DType), + sizePrevSlice * sizeof(DType)); +} + // print message to the tracker inline void TrackerPrint(const std::string &msg) { engine::GetEngine()->TrackerPrint(msg); diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 33b942d28..040393d0c 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -8,8 +8,9 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #define NOMINMAX -#include +#include #include +#include #include "allreduce_base.h" namespace rabit { @@ -221,6 +222,12 @@ void AllreduceBase::SetParam(const char *name, const char *val) { timeout_sec = atoi(val); utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second"); } + if (!strcmp(name, "rabit_enable_tcp_no_delay")) { + if (!strcmp(val, "true")) + rabit_enable_tcp_no_delay = true; + else + rabit_enable_tcp_no_delay = false; + } } /*! 
* \brief initialize connection to the tracker @@ -420,11 +427,16 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { this->parent_index = -1; // setup tree links and ring structure tree_links.plinks.clear(); + int tcpNoDelay = 1; for (size_t i = 0; i < all_links.size(); ++i) { utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive all_links[i].sock.SetNonBlock(true); all_links[i].sock.SetKeepAlive(true); + if (rabit_enable_tcp_no_delay) { + setsockopt(all_links[i].sock, IPPROTO_TCP, + TCP_NODELAY, reinterpret_cast(&tcpNoDelay), sizeof(tcpNoDelay)); + } if (tree_neighbors.count(all_links[i].rank) != 0) { if (all_links[i].rank == parent_rank) { parent_index = static_cast(tree_links.plinks.size()); diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 00f9d2ef1..07405f21a 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -61,6 +61,10 @@ class AllreduceBase : public IEngine { */ virtual void TrackerPrint(const std::string &msg); + /*! \brief get rank of previous node in ring topology*/ + virtual int GetRingPrevRank(void) const { + return ring_prev->rank; + } /*! \brief get rank */ virtual int GetRank(void) const { return rank; @@ -78,6 +82,35 @@ class AllreduceBase : public IEngine { virtual std::string GetHost(void) const { return host_uri; } + + /*! 
+ * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments + * use a ring based algorithm + * + * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered + * \param slice_begin beginning of the current slice + * \param slice_end end of the current slice + * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size + * \param _file caller file name used to generate unique cache key + * \param _line caller line number used to generate unique cache key + * \param _caller caller function name used to generate unique cache key + */ + virtual void Allgather(void *sendrecvbuf_, size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER) { + if (world_size == 1 || world_size == -1) return; + utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, + slice_begin, slice_end, size_prev_slice) == kSuccess, + "AllgatherRing failed"); + } /*! 
* \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -550,6 +583,8 @@ class AllreduceBase : public IEngine { int timeout_sec = 1800; // flag to enable rabit_timeout bool rabit_timeout = false; + // flag to enable the TCP_NODELAY socket option on links + bool rabit_enable_tcp_no_delay = false; }; } // namespace engine } // namespace rabit diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 428cc5c2f..95f429945 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -25,6 +25,7 @@ class AllreduceMock : public AllreduceRobust { force_local = 0; report_stats = 0; tsum_allreduce = 0.0; + tsum_allgather = 0.0; } // destructor virtual ~AllreduceMock(void) {} @@ -59,6 +60,20 @@ class AllreduceMock : public AllreduceRobust { _file, _line, _caller); tsum_allreduce += utils::GetTime() - tstart; } + virtual void Allgather(void *sendrecvbuf, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER) { + this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather"); + double tstart = utils::GetTime(); + AllreduceRobust::Allgather(sendrecvbuf, total_size, + slice_begin, slice_end, size_prev_slice, _file, _line, _caller); + tsum_allgather += utils::GetTime() - tstart; + } virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, const char* _file = _FILE, const int _line = _LINE, @@ -69,6 +84,7 @@ class AllreduceMock : public AllreduceRobust { virtual int LoadCheckPoint(Serializable *global_model, Serializable *local_model) { tsum_allreduce = 0.0; + tsum_allgather = 0.0; time_checkpoint = utils::GetTime(); if (force_local == 0) { return AllreduceRobust::LoadCheckPoint(global_model, local_model); @@ -98,10 +114,12 @@ class AllreduceMock : public AllreduceRobust { << ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length()) << ",check_tcost="<< tcost <<" sec" << ",allreduce_tcost=" << tsum_allreduce << " sec" + << 
",allgather_tcost=" << tsum_allgather << " sec" << ",between_chpt=" << tbet_chkpt << "sec\n"; this->TrackerPrint(ss.str()); } tsum_allreduce = 0.0; + tsum_allgather = 0.0; } virtual void LazyCheckPoint(const Serializable *global_model) { @@ -116,6 +134,8 @@ class AllreduceMock : public AllreduceRobust { int report_stats; // sum of allreduce double tsum_allreduce; + // sum of allgather + double tsum_allgather; double time_checkpoint; private: diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 242faee9f..e3db64c44 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -143,13 +143,84 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, size_t siz = 0; void* temp = cachebuf.Query(index, &siz); - _assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); - _assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); - _assert(siz > 0, "cache size should be greater than 0"); + utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); + utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); + utils::Assert(siz > 0, "cache size should be greater than 0"); std::memcpy(buf, temp, type_nbytes*count); return 0; } +/*! + * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments + * use a ring based algorithm + * + * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered + * \param slice_begin beginning of the current slice + * \param slice_end end of the current slice + * \param size_prev_slice size of the previous slice i.e. 
slice of node (rank - 1) % world_size + * \param _file caller file name used to generate unique cache key + * \param _line caller line number used to generate unique cache key + * \param _caller caller function name used to generate unique cache key + */ +void AllreduceRobust::Allgather(void *sendrecvbuf, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file, + const int _line, + const char* _caller) { + if (world_size == 1 || world_size == -1) return; + // genreate unique allgather signature + std::string key = std::string(_file) + "::" + std::to_string(_line) + "::" + + std::string(_caller) + "#" +std::to_string(total_size); + + // try fetch bootstrap allgather results from cache + if (!checkpoint_loaded && rabit_bootstrap_cache && + GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return; + + double start = utils::GetTime(); + bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq); + + if (resbuf.LastSeqNo() != -1 && + (result_buffer_round == -1 || + resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { + resbuf.DropLast(); + } + + void *temp = resbuf.AllocTemp(total_size, 1); + while (true) { + if (recovered) { + std::memcpy(temp, sendrecvbuf, total_size); break; + } else { + std::memcpy(temp, sendrecvbuf, total_size); + if (CheckAndRecover(TryAllgatherRing(temp, total_size, + slice_begin, slice_end, size_prev_slice))) { + std::memcpy(sendrecvbuf, temp, total_size); break; + } else { + recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq); + } + } + } + double delta = utils::GetTime() - start; + // log allgather latency + if (rabit_debug) { + utils::HandleLogInfo("[%d] allgather (%s) finished version %d, seq %d, take %f seconds\n", + rank, key.c_str(), version_number, seq_counter, delta); + } + + // if bootstrap allgather, store and fetch through cache + if (checkpoint_loaded || !rabit_bootstrap_cache) { + 
resbuf.PushTemp(seq_counter, total_size, 1); + seq_counter += 1; + } else { + SetBootstrapCache(key, sendrecvbuf, total_size, 1); + } +} /*! * \brief perform in-place allreduce, on sendrecvbuf diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 0e2fed407..a4bee7c58 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -51,6 +51,29 @@ class AllreduceRobust : public AllreduceBase { */ int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes, const size_t count); + /*! + * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * the data provided by current node k is [slice_begin, slice_end), + * the next node's segment must start with slice_end + * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments + * use a ring based algorithm + * + * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually + * \param total_size total size of data to be gathered + * \param slice_begin beginning of the current slice + * \param slice_end end of the current slice + * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size + * \param _file caller file name used to generate unique cache key + * \param _line caller line number used to generate unique cache key + * \param _caller caller function name used to generate unique cache key + */ + virtual void Allgather(void *sendrecvbuf_, size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER); /*! 
* \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe diff --git a/src/c_api.cc b/src/c_api.cc index 0cab3701b..61e5045dc 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -13,7 +13,7 @@ namespace c_api { // helper use to avoid BitOR operator template struct FHelper { - inline static void + static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), @@ -25,7 +25,7 @@ struct FHelper { template struct FHelper { - inline static void + static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), @@ -35,11 +35,11 @@ struct FHelper { }; template -inline void Allreduce_(void *sendrecvbuf_, - size_t count, - engine::mpi::DataType enum_dtype, - void (*prepare_fun)(void *arg), - void *prepare_arg) { +void Allreduce_(void *sendrecvbuf_, + size_t count, + engine::mpi::DataType enum_dtype, + void (*prepare_fun)(void *arg), + void *prepare_arg) { using namespace engine::mpi; switch (enum_dtype) { case kChar: @@ -85,12 +85,12 @@ inline void Allreduce_(void *sendrecvbuf_, default: utils::Error("unknown data_type"); } } -inline void Allreduce(void *sendrecvbuf, - size_t count, - engine::mpi::DataType enum_dtype, - engine::mpi::OpType enum_op, - void (*prepare_fun)(void *arg), - void *prepare_arg) { +void Allreduce(void *sendrecvbuf, + size_t count, + engine::mpi::DataType enum_dtype, + engine::mpi::OpType enum_op, + void (*prepare_fun)(void *arg), + void *prepare_arg) { using namespace engine::mpi; switch (enum_op) { case kMax: @@ -120,8 +120,45 @@ inline void Allreduce(void *sendrecvbuf, default: utils::Error("unknown enum_op"); } } - - +void Allgather(void *sendrecvbuf_, + size_t total_size, + size_t beginIndex, + size_t size_node_slice, + size_t size_prev_slice, + int enum_dtype) { + using namespace engine::mpi; + size_t type_size = 0; + switch (enum_dtype) { + case kChar: + type_size = sizeof(char); + break; + case kUChar: + type_size = sizeof(unsigned char); + break; + case kInt: + type_size 
= sizeof(int); + break; + case kUInt: + type_size = sizeof(unsigned); + break; + case kLong: + type_size = sizeof(int64_t); + break; + case kULong: + type_size = sizeof(uint64_t); + break; + case kFloat: + type_size = sizeof(float); + break; + case kDouble: + type_size = sizeof(double); + break; + default: utils::Error("unknown data_type"); + } + engine::Allgather + (sendrecvbuf_, total_size * type_size, beginIndex * type_size, + (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); +} // wrapper for serialization struct ReadWrapper : public Serializable { @@ -170,6 +207,10 @@ bool RabitFinalize() { return rabit::Finalize(); } +int RabitGetRingPrevRank() { + return rabit::GetRingPrevRank(); +} + int RabitGetRank() { return rabit::GetRank(); } @@ -203,6 +244,21 @@ void RabitBroadcast(void *sendrecv_data, rabit::Broadcast(sendrecv_data, size, root); } +void RabitAllgather(void *sendrecvbuf_, + size_t total_size, + size_t beginIndex, + size_t size_node_slice, + size_t size_prev_slice, + int enum_dtype) { + rabit::c_api::Allgather(sendrecvbuf_, + total_size, + beginIndex, + size_node_slice, + size_prev_slice, + enum_dtype); +} + + void RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, diff --git a/src/engine.cc b/src/engine.cc index a0f8d595e..d43017d1a 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -84,6 +84,15 @@ IEngine *GetEngine() { } } +// perform in-place allgather, on sendrecvbuf +void Allgather(void *sendrecvbuf_, size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice) { + GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice); +} + + // perform in-place allreduce, on sendrecvbuf void Allreduce_(void *sendrecvbuf, size_t type_nbytes, diff --git a/src/engine_empty.cc b/src/engine_empty.cc index 0a7926628..b548c1782 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -25,6 +25,20 @@ class EmptyEngine : public IEngine { EmptyEngine(void) { 
version_number = 0; } + virtual void Allgather(void *sendrecvbuf_, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file, + const int _line, + const char* _caller) { + utils::Error("EmptyEngine:: Allgather is not supported"); + } + virtual int GetRingPrevRank(void) const { + utils::Error("EmptyEngine:: GetRingPrevRank is not supported"); + return -1;  // not expected to be reached; gives the non-void function a defined return value + } virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 3379765c1..f61770802 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -27,6 +27,16 @@ class MPIEngine : public IEngine { MPIEngine(void) { version_number = 0; } + virtual void Allgather(void *sendrecvbuf_, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file, + const int _line, + const char* _caller) { + utils::Error("MPIEngine:: Allgather is not supported"); + } virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, @@ -39,6 +49,10 @@ class MPIEngine : public IEngine { utils::Error("MPIEngine:: Allreduce is not supported,"\ "use Allreduce_ instead"); } + virtual int GetRingPrevRank(void) const { + utils::Error("MPIEngine:: GetRingPrevRank is not supported"); + return -1;  // not expected to be reached; gives the non-void function a defined return value + } virtual void Broadcast(void *sendrecvbuf_, size_t size, int root, const char* _file, const int _line, const char* _caller) { diff --git a/test/cpp/allreduce_base_test.cpp b/test/cpp/allreduce_base_test.cpp new file mode 100644 index 000000000..65a3dd50b --- /dev/null +++ b/test/cpp/allreduce_base_test.cpp @@ -0,0 +1,66 @@ +#define RABIT_CXXTESTDEFS_H +#include + +#include +#include +#include "../../src/allreduce_base.h" + +TEST(allreduce_base, init_task) +{ + rabit::engine::AllreduceBase base; + + std::string rabit_task_id = "rabit_task_id=1"; + char cmd[rabit_task_id.size()+1]; + std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); + cmd[rabit_task_id.size()] = '\0'; + + char* argv[] = 
{cmd}; + base.Init(1, argv); + EXPECT_EQ(base.task_id, "1"); +} + +TEST(allreduce_base, init_with_cache_on) +{ + rabit::engine::AllreduceBase base; + + std::string rabit_task_id = "rabit_task_id=1"; + char cmd[rabit_task_id.size()+1]; + std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); + cmd[rabit_task_id.size()] = '\0'; + + std::string rabit_bootstrap_cache = "rabit_bootstrap_cache=1"; + char cmd2[rabit_bootstrap_cache.size()+1]; + std::copy(rabit_bootstrap_cache.begin(), rabit_bootstrap_cache.end(), cmd2); + cmd2[rabit_bootstrap_cache.size()] = '\0'; + + std::string rabit_debug = "rabit_debug=1"; + char cmd3[rabit_debug.size()+1]; + std::copy(rabit_debug.begin(), rabit_debug.end(), cmd3); + cmd3[rabit_debug.size()] = '\0'; + + char* argv[] = {cmd, cmd2, cmd3}; + base.Init(3, argv); + EXPECT_EQ(base.task_id, "1"); + EXPECT_EQ(base.rabit_bootstrap_cache, 1); + EXPECT_EQ(base.rabit_debug, 1); +} + +TEST(allreduce_base, init_with_ring_reduce) +{ + rabit::engine::AllreduceBase base; + + std::string rabit_task_id = "rabit_task_id=1"; + char cmd[rabit_task_id.size()+1]; + std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); + cmd[rabit_task_id.size()] = '\0'; + + std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1"; + char cmd2[rabit_reduce_ring_mincount.size()+1]; + std::copy(rabit_reduce_ring_mincount.begin(), rabit_reduce_ring_mincount.end(), cmd2); + cmd2[rabit_reduce_ring_mincount.size()] = '\0'; + + char* argv[] = {cmd, cmd2}; + base.Init(2, argv); + EXPECT_EQ(base.task_id, "1"); + EXPECT_EQ(base.reduce_ring_mincount, 1); +} diff --git a/test/cpp/allreduce_mock_test.cpp b/test/cpp/allreduce_mock_test.cpp new file mode 100644 index 000000000..ec3190c96 --- /dev/null +++ b/test/cpp/allreduce_mock_test.cpp @@ -0,0 +1,51 @@ +#define RABIT_CXXTESTDEFS_H +#include + +#include +#include +#include "../../src/allreduce_mock.h" + +TEST(allreduce_mock, mock_allreduce) +{ + rabit::engine::AllreduceMock m; + + std::string mock_str = 
"mock=0,0,0,0"; + char cmd[mock_str.size()+1]; + std::copy(mock_str.begin(), mock_str.end(), cmd); + cmd[mock_str.size()] = '\0'; + + char* argv[] = {cmd}; + m.Init(1, argv); + m.rank = 0; + EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), ""); +} + +TEST(allreduce_mock, mock_broadcast) +{ + rabit::engine::AllreduceMock m; + std::string mock_str = "mock=0,1,2,0"; + char cmd[mock_str.size()+1]; + std::copy(mock_str.begin(), mock_str.end(), cmd); + cmd[mock_str.size()] = '\0'; + char* argv[] = {cmd}; + m.Init(1, argv); + m.rank = 0; + m.version_number=1; + m.seq_counter=2; + EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), ""); +} + +TEST(allreduce_mock, mock_gather) +{ + rabit::engine::AllreduceMock m; + std::string mock_str = "mock=3,13,22,0"; + char cmd[mock_str.size()+1]; + std::copy(mock_str.begin(), mock_str.end(), cmd); + cmd[mock_str.size()] = '\0'; + char* argv[] = {cmd}; + m.Init(1, argv); + m.rank = 3; + m.version_number=13; + m.seq_counter=22; + EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), ""); +} diff --git a/test/model_recover.cc b/test/model_recover.cc index 48a4da99c..181638c07 100644 --- a/test/model_recover.cc +++ b/test/model_recover.cc @@ -26,7 +26,7 @@ class Model : public rabit::Serializable { } }; -inline void TestMax(Model *model, int ntrial, int iter) { +inline void TestMax(Model *model, int iter) { int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); const int z = iter + 111; @@ -47,7 +47,7 @@ inline void TestMax(Model *model, int ntrial, int iter) { model->data = ndata; } -inline void TestSum(Model *model, int ntrial, int iter) { +inline void TestSum(Model *model, int iter) { int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); const int z = 131 + iter; @@ -69,7 +69,30 @@ inline void TestSum(Model *model, int ntrial, int iter) { model->data = ndata; } -inline void TestBcast(size_t n, int root, int ntrial, int iter) { +inline 
void TestAllgather(Model *model, int iter) { + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + const int z = 131 + iter; + + std::vector ndata(model->data.size() * nproc); + size_t beginSlice = rank * model->data.size(); + for (size_t i = 0; i < model->data.size(); ++i) { + ndata[beginSlice + i] = (i * (rank+1)) % z + model->data[i]; + } + Allgather(&ndata[0], ndata.size(), beginSlice, + model->data.size(), model->data.size()); + + for (size_t i = 0; i < ndata.size(); ++i) { + int curRank = i / model->data.size(); + int remainder = i % model->data.size(); + float data = (remainder * (curRank+1)) % z + model->data[remainder]; + utils::Check(fabsf(data - ndata[i]) < 1e-5 , + "[%d] TestAllgather check failure, local=%g, allgatherring=%g", rank, data, ndata[i]); + } + model->data = ndata; +} + +inline void TestBcast(size_t n, int root) { int rank = rabit::GetRank(); std::string s; s.resize(n); for (size_t i = 0; i < n; ++i) { @@ -113,19 +136,22 @@ int main(int argc, char *argv[]) { printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); for (int r = iter; r < 3; ++r) { - TestMax(&model, ntrial, r); + TestMax(&model, r); printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); int step = std::max(nproc / 3, 1); for (int i = 0; i < nproc; i += step) { - TestBcast(n, i, ntrial, r); + TestBcast(n, i); } printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); - TestSum(&model, ntrial, r); + TestSum(&model, r); printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + TestAllgather(&model, r); + printf("[%d] !!!TestAllgather pass, iter=%d\n", rank, r); rabit::CheckPoint(&model); printf("[%d] !!!Checkpoint pass, iter=%d\n", rank, r); } rabit::Finalize(); return 0; } +