Expose RabitAllGatherRing and RabitGetRingPrevRank (#113)

* add unittests

* Expose RabitAllGatherRing and RabitGetRingPrevRank

* Enabled TCP_NODELAY to decrease latency
This commit is contained in:
nateagr 2019-11-12 12:55:32 +01:00 committed by Jiaming Yuan
parent 90e2239372
commit 1907b25cd0
15 changed files with 509 additions and 25 deletions

View File

@ -43,6 +43,12 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]);
*/
RABIT_DLL bool RabitFinalize(void);
/*!
* \brief get rank of previous process in ring topology
* \return rank number of worker
* */
RABIT_DLL int RabitGetRingPrevRank(void);
/*!
* \brief get rank of current process
* \return rank number of worker
@ -87,6 +93,30 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
*/
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root);
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype
* \param size_node_slice size of the current node slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
RABIT_DLL void RabitAllgather(void *sendrecvbuf,
size_t total_size,
size_t beginIndex,
size_t size_node_slice,
size_t size_prev_slice,
int enum_dtype);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe

View File

@ -53,6 +53,29 @@ class IEngine {
const MPI::Datatype &dtype);
/*! \brief virtual destructor */
virtual ~IEngine() {}
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
/*!
* \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe
@ -165,6 +188,8 @@ class IEngine {
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
/*! \brief gets rank of previous node in ring topology */
virtual int GetRingPrevRank(void) const = 0;
/*! \brief gets rank of current node */
virtual int GetRank(void) const = 0;
/*! \brief gets total number of nodes */
@ -212,6 +237,26 @@ enum DataType {
kULongLong = 9
};
} // namespace mpi
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
void Allgather(void* sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice);
/*!
* \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI

View File

@ -110,6 +110,10 @@ inline bool Init(int argc, char *argv[]) {
inline bool Finalize(void) {
return engine::Finalize();
}
// get the rank of the previous worker in ring topology
// (i.e. the worker whose data segment precedes this one in ring allgather)
inline int GetRingPrevRank(void) {
  return engine::GetEngine()->GetRingPrevRank();
}
// get the rank of current process
inline int GetRank(void) {
return engine::GetEngine()->GetRank();
@ -194,6 +198,16 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
}
#endif // C++11
// Performs inplace Allgather
//
// All sizes and offsets here are expressed in numbers of DType elements;
// they are converted to byte offsets before being handed to the engine,
// which operates on raw memory.
//
// \param sendrecvbuf the ring buffer holding all segments (totalSize elements)
// \param totalSize total number of elements gathered across all nodes
// \param beginIndex element index where this node's own slice begins
// \param sizeNodeSlice number of elements in this node's slice
// \param sizePrevSlice number of elements in the previous node's slice,
//        i.e. the slice of node (rank - 1) % world_size
template<typename DType>
inline void Allgather(DType *sendrecvbuf, size_t totalSize,
                      size_t beginIndex, size_t sizeNodeSlice,
                      size_t sizePrevSlice) {
  // engine::Allgather takes [slice_begin, slice_end) in bytes;
  // slice_end = (beginIndex + sizeNodeSlice) scaled by the element size
  engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
                    (beginIndex + sizeNodeSlice) * sizeof(DType),
                    sizePrevSlice * sizeof(DType));
}
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);

View File

@ -8,8 +8,9 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <map>
#include <netinet/tcp.h>
#include <cstring>
#include <map>
#include "allreduce_base.h"
namespace rabit {
@ -221,6 +222,12 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
timeout_sec = atoi(val);
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
}
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true"))
rabit_enable_tcp_no_delay = true;
else
rabit_enable_tcp_no_delay = false;
}
}
/*!
* \brief initialize connection to the tracker
@ -420,11 +427,16 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) {
setsockopt(all_links[i].sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
}
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());

View File

@ -61,6 +61,10 @@ class AllreduceBase : public IEngine {
*/
virtual void TrackerPrint(const std::string &msg);
/*! \brief get rank of previous node in ring topology */
virtual int GetRingPrevRank(void) const {
  // NOTE(review): assumes ring_prev has been set up (by link/ring
  // initialization) before this is called; calling it on an
  // uninitialized engine would dereference an invalid pointer — confirm
  // callers only use it after Init.
  return ring_prev->rank;
}
/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
@ -78,6 +82,35 @@ class AllreduceBase : public IEngine {
virtual std::string GetHost(void) const {
return host_uri;
}
/*!
 * \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
 *        the data provided by current node k is [slice_begin, slice_end),
 *        the next node's segment must start with slice_end;
 *        after the call of Allgather, sendrecvbuf_ contains all the contents including all segments;
 *        uses a ring based algorithm
 *
 * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
 * \param total_size total size of data to be gathered
 * \param slice_begin beginning of the current slice
 * \param slice_end end of the current slice
 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
 * \param _file caller file name used to generate unique cache key
 * \param _line caller line number used to generate unique cache key
 * \param _caller caller function name used to generate unique cache key
 */
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file = _FILE,
                       const int _line = _LINE,
                       const char* _caller = _CALLER) {
  // single worker (or uninitialized world_size == -1): nothing to exchange
  if (world_size == 1 || world_size == -1) return;
  // the base engine has no fault tolerance: any ring failure is fatal here
  // (the robust engine overrides this method with a recoverable version)
  utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
                                 slice_begin, slice_end, size_prev_slice) == kSuccess,
                "AllgatherRing failed");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
@ -550,6 +583,8 @@ class AllreduceBase : public IEngine {
int timeout_sec = 1800;
// flag to enable rabit_timeout
bool rabit_timeout = false;
// Enable TCP node delay
bool rabit_enable_tcp_no_delay = false;
};
} // namespace engine
} // namespace rabit

View File

@ -25,6 +25,7 @@ class AllreduceMock : public AllreduceRobust {
force_local = 0;
report_stats = 0;
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
}
// destructor
virtual ~AllreduceMock(void) {}
@ -59,6 +60,20 @@ class AllreduceMock : public AllreduceRobust {
_file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart;
}
/*!
 * \brief mock Allgather: checks the mock failure key, then times and
 *        delegates to the real implementation
 * (parameters are identical in meaning to AllreduceRobust::Allgather)
 */
virtual void Allgather(void *sendrecvbuf,
                       size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file = _FILE,
                       const int _line = _LINE,
                       const char* _caller = _CALLER) {
  this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
  double tstart = utils::GetTime();
  // forward the original caller's file/line/caller so the bootstrap cache
  // key built in AllreduceRobust::Allgather identifies the real call site
  // rather than this wrapper (consistent with the Allreduce override,
  // which forwards them)
  AllreduceRobust::Allgather(sendrecvbuf, total_size,
                             slice_begin, slice_end, size_prev_slice,
                             _file, _line, _caller);
  tsum_allgather += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
@ -69,6 +84,7 @@ class AllreduceMock : public AllreduceRobust {
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
time_checkpoint = utils::GetTime();
if (force_local == 0) {
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
@ -98,10 +114,12 @@ class AllreduceMock : public AllreduceRobust {
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
<< ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
<< ",allgather_tcost=" << tsum_allgather << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
@ -116,6 +134,8 @@ class AllreduceMock : public AllreduceRobust {
int report_stats;
// sum of allreduce
double tsum_allreduce;
// sum of allgather
double tsum_allgather;
double time_checkpoint;
private:

View File

@ -143,13 +143,84 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
size_t siz = 0;
void* temp = cachebuf.Query(index, &siz);
_assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
_assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
_assert(siz > 0, "cache size should be greater than 0");
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
utils::Assert(siz > 0, "cache size should be greater than 0");
std::memcpy(buf, temp, type_nbytes*count);
return 0;
}
/*!
 * \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf,
 *  the data provided by current node k is [slice_begin, slice_end),
 *  the next node's segment must start with slice_end;
 *  after the call of Allgather, sendrecvbuf_ contains all the contents including all segments;
 *  uses a ring based algorithm
 *
 * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
 * \param total_size total size of data to be gathered
 * \param slice_begin beginning of the current slice
 * \param slice_end end of the current slice
 * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
 * \param _file caller file name used to generate unique cache key
 * \param _line caller line number used to generate unique cache key
 * \param _caller caller function name used to generate unique cache key
 */
void AllreduceRobust::Allgather(void *sendrecvbuf,
                                size_t total_size,
                                size_t slice_begin,
                                size_t slice_end,
                                size_t size_prev_slice,
                                const char* _file,
                                const int _line,
                                const char* _caller) {
  // single worker (or uninitialized world_size == -1): nothing to exchange
  if (world_size == 1 || world_size == -1) return;
  // generate unique allgather signature from the call site and payload size
  std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
                    + std::string(_caller) + "#" + std::to_string(total_size);
  // try fetch bootstrap allgather results from cache
  if (!checkpoint_loaded && rabit_bootstrap_cache &&
      GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return;
  double start = utils::GetTime();
  // first attempt to replay a result recovered from a previous round
  bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
  // drop the last result slot unless it is retained by the buffer rotation
  if (resbuf.LastSeqNo() != -1 &&
      (result_buffer_round == -1 ||
       resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
    resbuf.DropLast();
  }
  void *temp = resbuf.AllocTemp(total_size, 1);
  // retry loop: run the ring allgather on a scratch copy; on failure, try to
  // recover from peers and loop until either path succeeds
  while (true) {
    if (recovered) {
      std::memcpy(temp, sendrecvbuf, total_size); break;
    } else {
      std::memcpy(temp, sendrecvbuf, total_size);
      if (CheckAndRecover(TryAllgatherRing(temp, total_size,
                                           slice_begin, slice_end, size_prev_slice))) {
        std::memcpy(sendrecvbuf, temp, total_size); break;
      } else {
        recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
      }
    }
  }
  double delta = utils::GetTime() - start;
  // log allgather latency
  if (rabit_debug) {
    utils::HandleLogInfo("[%d] allgather (%s) finished version %d, seq %d, take %f seconds\n",
                         rank, key.c_str(), version_number, seq_counter, delta);
  }
  // if bootstrap allgather, store and fetch through cache;
  // otherwise commit the result buffer and advance the sequence counter
  if (checkpoint_loaded || !rabit_bootstrap_cache) {
    resbuf.PushTemp(seq_counter, total_size, 1);
    seq_counter += 1;
  } else {
    SetBootstrapCache(key, sendrecvbuf, total_size, 1);
  }
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf

View File

@ -51,6 +51,29 @@ class AllreduceRobust : public AllreduceBase {
*/
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
const size_t count);
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe

View File

@ -13,7 +13,7 @@ namespace c_api {
// helper use to avoid BitOR operator
template<typename OP, typename DType>
struct FHelper {
inline static void
static void
Allreduce(DType *senrecvbuf_,
size_t count,
void (*prepare_fun)(void *arg),
@ -25,7 +25,7 @@ struct FHelper {
template<typename DType>
struct FHelper<op::BitOR, DType> {
inline static void
static void
Allreduce(DType *senrecvbuf_,
size_t count,
void (*prepare_fun)(void *arg),
@ -35,11 +35,11 @@ struct FHelper<op::BitOR, DType> {
};
template<typename OP>
inline void Allreduce_(void *sendrecvbuf_,
size_t count,
engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
void Allreduce_(void *sendrecvbuf_,
size_t count,
engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
switch (enum_dtype) {
case kChar:
@ -85,12 +85,12 @@ inline void Allreduce_(void *sendrecvbuf_,
default: utils::Error("unknown data_type");
}
}
inline void Allreduce(void *sendrecvbuf,
size_t count,
engine::mpi::DataType enum_dtype,
engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
void Allreduce(void *sendrecvbuf,
size_t count,
engine::mpi::DataType enum_dtype,
engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
switch (enum_op) {
case kMax:
@ -120,8 +120,45 @@ inline void Allreduce(void *sendrecvbuf,
default: utils::Error("unknown enum_op");
}
}
/*
 * C API allgather: translates element counts of the given mpi::DataType
 * into byte offsets, then forwards to the engine-level ring allgather.
 */
void Allgather(void *sendrecvbuf_,
               size_t total_size,
               size_t beginIndex,
               size_t size_node_slice,
               size_t size_prev_slice,
               int enum_dtype) {
  using namespace engine::mpi;
  // map the data-type tag onto the byte width of one element
  size_t type_size = 0;
  switch (enum_dtype) {
    case kChar:
      type_size = sizeof(char);
      break;
    case kUChar:
      type_size = sizeof(unsigned char);
      break;
    case kInt:
      type_size = sizeof(int);
      break;
    case kUInt:
      type_size = sizeof(unsigned);
      break;
    case kLong:
      type_size = sizeof(int64_t);
      break;
    case kULong:
      type_size = sizeof(uint64_t);
      break;
    case kFloat:
      type_size = sizeof(float);
      break;
    case kDouble:
      type_size = sizeof(double);
      break;
    // NOTE(review): mpi::DataType defines more tags than handled here
    // (e.g. kULongLong = 9); those fall through to the error path.
    // Also, if utils::Error ever returned, type_size would stay 0 and the
    // call below would degenerate to a zero-byte gather — confirm Error
    // is non-returning.
    default: utils::Error("unknown data_type");
  }
  // engine::Allgather expects byte quantities:
  // (total bytes, slice_begin, slice_end, previous slice size)
  engine::Allgather
      (sendrecvbuf_, total_size * type_size, beginIndex * type_size,
       (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
}
// wrapper for serialization
struct ReadWrapper : public Serializable {
@ -170,6 +207,10 @@ bool RabitFinalize() {
return rabit::Finalize();
}
int RabitGetRingPrevRank() {
return rabit::GetRingPrevRank();
}
int RabitGetRank() {
return rabit::GetRank();
}
@ -203,6 +244,21 @@ void RabitBroadcast(void *sendrecv_data,
rabit::Broadcast(sendrecv_data, size, root);
}
/*
 * C entry point for the ring allgather.  A thin shim: the element-count to
 * byte-offset translation happens inside rabit::c_api::Allgather.
 */
void RabitAllgather(void *sendrecvbuf_,
                    size_t total_size,
                    size_t beginIndex,
                    size_t size_node_slice,
                    size_t size_prev_slice,
                    int enum_dtype) {
  rabit::c_api::Allgather(sendrecvbuf_, total_size, beginIndex,
                          size_node_slice, size_prev_slice, enum_dtype);
}
void RabitAllreduce(void *sendrecvbuf,
size_t count,
int enum_dtype,

View File

@ -84,6 +84,15 @@ IEngine *GetEngine() {
}
}
// perform in-place allgather, on sendrecvbuf
// (thin dispatcher: the caller-location arguments of IEngine::Allgather
// are filled in by that interface's default arguments)
void Allgather(void *sendrecvbuf_, size_t total_size,
               size_t slice_begin,
               size_t slice_end,
               size_t size_prev_slice) {
  GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice);
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,

View File

@ -25,6 +25,19 @@ class EmptyEngine : public IEngine {
EmptyEngine(void) {
version_number = 0;
}
// allgather is meaningless on the single-process stub engine: always error
virtual void Allgather(void *sendrecvbuf_,
                       size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file,
                       const int _line,
                       const char* _caller) {
  utils::Error("EmptyEngine:: Allgather is not supported");
}
// ring topology does not exist on the stub engine: always error
virtual int GetRingPrevRank(void) const {
  utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
  // utils::Error is not declared [[noreturn]]; without this return the
  // function could flow off its end (undefined behavior, -Wreturn-type)
  return -1;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,

View File

@ -27,6 +27,16 @@ class MPIEngine : public IEngine {
MPIEngine(void) {
version_number = 0;
}
// ring allgather is not implemented for the MPI backend: always error
virtual void Allgather(void *sendrecvbuf_,
                       size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file,
                       const int _line,
                       const char* _caller) {
  utils::Error("MPIEngine:: Allgather is not supported");
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
@ -39,6 +49,9 @@ class MPIEngine : public IEngine {
utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
// ring topology is not exposed by the MPI backend: always error
virtual int GetRingPrevRank(void) const {
  utils::Error("MPIEngine:: GetRingPrevRank is not supported");
  // utils::Error is not declared [[noreturn]]; without this return the
  // function could flow off its end (undefined behavior, -Wreturn-type)
  return -1;
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line,
const char* _caller) {

View File

@ -0,0 +1,66 @@
#define RABIT_CXXTESTDEFS_H
#include <gtest/gtest.h>
#include <string>
#include <iostream>
#include "../../src/allreduce_base.h"
// Init parses "rabit_task_id=<v>" from argv into task_id.
TEST(allreduce_base, init_task)
{
  rabit::engine::AllreduceBase base;
  // std::string owns a contiguous, null-terminated buffer (C++11), so
  // &s[0] is a valid mutable argv entry; the previous version used a
  // variable-length array, which is not standard C++.
  std::string rabit_task_id = "rabit_task_id=1";
  char* argv[] = {&rabit_task_id[0]};
  base.Init(1, argv);
  EXPECT_EQ(base.task_id, "1");
}
// Init parses rabit_bootstrap_cache / rabit_debug flags alongside task id.
TEST(allreduce_base, init_with_cache_on)
{
  rabit::engine::AllreduceBase base;
  // use std::string buffers directly as mutable argv entries (&s[0] is a
  // valid null-terminated C string since C++11); replaces the non-standard
  // variable-length arrays used previously
  std::string rabit_task_id = "rabit_task_id=1";
  std::string rabit_bootstrap_cache = "rabit_bootstrap_cache=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_task_id[0], &rabit_bootstrap_cache[0], &rabit_debug[0]};
  base.Init(3, argv);
  EXPECT_EQ(base.task_id, "1");
  EXPECT_EQ(base.rabit_bootstrap_cache, 1);
  EXPECT_EQ(base.rabit_debug, 1);
}
// Init parses rabit_reduce_ring_mincount into reduce_ring_mincount.
TEST(allreduce_base, init_with_ring_reduce)
{
  rabit::engine::AllreduceBase base;
  // use std::string buffers directly as mutable argv entries (&s[0] is a
  // valid null-terminated C string since C++11); replaces the non-standard
  // variable-length arrays used previously
  std::string rabit_task_id = "rabit_task_id=1";
  std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1";
  char* argv[] = {&rabit_task_id[0], &rabit_reduce_ring_mincount[0]};
  base.Init(2, argv);
  EXPECT_EQ(base.task_id, "1");
  EXPECT_EQ(base.reduce_ring_mincount, 1);
}

View File

@ -0,0 +1,51 @@
#define RABIT_CXXTESTDEFS_H
#include <gtest/gtest.h>
#include <string>
#include <iostream>
#include "../../src/allreduce_mock.h"
// mock entry "0,0,0,0" matches (rank, version, seq, trial) of the first
// Allreduce call, which must then exit with the mock error code.
TEST(allreduce_mock, mock_allreduce)
{
  rabit::engine::AllreduceMock m;
  // &s[0] is a valid mutable, null-terminated buffer (C++11); replaces the
  // non-standard variable-length array used previously
  std::string mock_str = "mock=0,0,0,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 0;
  EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
}
// mock entry "0,1,2,0" matches the state set below, so Broadcast must exit
// with the mock error code.
TEST(allreduce_mock, mock_broadcast)
{
  rabit::engine::AllreduceMock m;
  // &s[0] is a valid mutable, null-terminated buffer (C++11); replaces the
  // non-standard variable-length array used previously
  std::string mock_str = "mock=0,1,2,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 0;
  m.version_number=1;
  m.seq_counter=2;
  EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
}
// mock entry "3,13,22,0" matches the state set below, so Allgather must exit
// with the mock error code.
TEST(allreduce_mock, mock_gather)
{
  rabit::engine::AllreduceMock m;
  // &s[0] is a valid mutable, null-terminated buffer (C++11); replaces the
  // non-standard variable-length array used previously
  std::string mock_str = "mock=3,13,22,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 3;
  m.version_number=13;
  m.seq_counter=22;
  EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
}

View File

@ -26,7 +26,7 @@ class Model : public rabit::Serializable {
}
};
inline void TestMax(Model *model, int ntrial, int iter) {
inline void TestMax(Model *model, int iter) {
int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
const int z = iter + 111;
@ -47,7 +47,7 @@ inline void TestMax(Model *model, int ntrial, int iter) {
model->data = ndata;
}
inline void TestSum(Model *model, int ntrial, int iter) {
inline void TestSum(Model *model, int iter) {
int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
const int z = 131 + iter;
@ -69,7 +69,30 @@ inline void TestSum(Model *model, int ntrial, int iter) {
model->data = ndata;
}
inline void TestBcast(size_t n, int root, int ntrial, int iter) {
// End-to-end check of the ring Allgather: every worker fills only its own
// slice of a ring buffer, runs Allgather, then verifies all slices match
// what each owning rank must have produced.
inline void TestAllgather(Model *model, int iter) {
  int rank = rabit::GetRank();
  int nproc = rabit::GetWorldSize();
  const int z = 131 + iter;
  // buffer holds one slice of model->data.size() floats per worker
  std::vector<float> ndata(model->data.size() * nproc);
  size_t beginSlice = rank * model->data.size();
  // populate only this worker's slice with a rank/iteration-dependent value
  for (size_t i = 0; i < model->data.size(); ++i) {
    ndata[beginSlice + i] = (i * (rank+1)) % z + model->data[i];
  }
  // arguments are element counts: total, begin index, this slice size, and
  // previous slice size (all slices are equal here)
  Allgather(&ndata[0], ndata.size(), beginSlice,
            model->data.size(), model->data.size());
  // recompute every rank's expected slice locally and compare elementwise
  for (size_t i = 0; i < ndata.size(); ++i) {
    int curRank = i / model->data.size();
    int remainder = i % model->data.size();
    float data = (remainder * (curRank+1)) % z + model->data[remainder];
    utils::Check(fabsf(data - ndata[i]) < 1e-5 ,
                 "[%d] TestAllgather check failure, local=%g, allgatherring=%g", rank, data, ndata[i]);
  }
  model->data = ndata;
}
inline void TestBcast(size_t n, int root) {
int rank = rabit::GetRank();
std::string s; s.resize(n);
for (size_t i = 0; i < n; ++i) {
@ -113,19 +136,22 @@ int main(int argc, char *argv[]) {
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
for (int r = iter; r < 3; ++r) {
TestMax(&model, ntrial, r);
TestMax(&model, r);
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
int step = std::max(nproc / 3, 1);
for (int i = 0; i < nproc; i += step) {
TestBcast(n, i, ntrial, r);
TestBcast(n, i);
}
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
TestSum(&model, ntrial, r);
TestSum(&model, r);
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
TestAllgather(&model, r);
printf("[%d] !!!TestAllgather pass, iter=%d\n", rank, r);
rabit::CheckPoint(&model);
printf("[%d] !!!Checkpoint pass, iter=%d\n", rank, r);
}
rabit::Finalize();
return 0;
}