Expose RabitAllGatherRing and RabitGetRingPrevRank (#113)
* add unittests * Expose RabitAllGatherRing and RabitGetRingPrevRank * Enabled TCP_NODELAY to decrease latency
This commit is contained in:
parent
90e2239372
commit
1907b25cd0
@ -43,6 +43,12 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]);
|
||||
*/
|
||||
RABIT_DLL bool RabitFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of previous process in ring topology
|
||||
* \return rank number of the previous worker in the ring
|
||||
* */
|
||||
RABIT_DLL int RabitGetRingPrevRank(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of current process
|
||||
* \return rank number of worker
|
||||
@ -87,6 +93,30 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
|
||||
*/
|
||||
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
|
||||
rbt_ulong size, int root);
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype
|
||||
* \param size_node_slice size of the current node slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
|
||||
* \note declared void, so no ReturnType code (kSuccess, kSockError, kGetExcept) is returned to the caller
|
||||
* \sa ReturnType
|
||||
*/
|
||||
RABIT_DLL void RabitAllgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype);
|
||||
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
|
||||
@ -53,6 +53,29 @@ class IEngine {
|
||||
const MPI::Datatype &dtype);
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IEngine() {}
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) = 0;
|
||||
/*!
|
||||
* \brief performs in-place Allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -165,6 +188,8 @@ class IEngine {
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
virtual int VersionNumber(void) const = 0;
|
||||
/*! \brief gets rank of previous node in ring topology */
|
||||
virtual int GetRingPrevRank(void) const = 0;
|
||||
/*! \brief gets rank of current node */
|
||||
virtual int GetRank(void) const = 0;
|
||||
/*! \brief gets total number of nodes */
|
||||
@ -212,6 +237,26 @@ enum DataType {
|
||||
kULongLong = 9
|
||||
};
|
||||
} // namespace mpi
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \note declared void, so no ReturnType code (kSuccess, kSockError, kGetExcept) is returned to the caller
|
||||
* \sa ReturnType
|
||||
*/
|
||||
void Allgather(void* sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice);
|
||||
/*!
|
||||
* \brief perform in-place Allreduce, on sendrecvbuf
|
||||
* this is an internal function used by rabit to be able to compile with MPI
|
||||
|
||||
@ -110,6 +110,10 @@ inline bool Init(int argc, char *argv[]) {
|
||||
inline bool Finalize(void) {
|
||||
return engine::Finalize();
|
||||
}
|
||||
// get the rank of the previous worker in ring topology
|
||||
inline int GetRingPrevRank(void) {
|
||||
return engine::GetEngine()->GetRingPrevRank();
|
||||
}
|
||||
// get the rank of current process
|
||||
inline int GetRank(void) {
|
||||
return engine::GetEngine()->GetRank();
|
||||
@ -194,6 +198,16 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// Performs inplace Allgather
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf, size_t totalSize,
|
||||
size_t beginIndex, size_t sizeNodeSlice,
|
||||
size_t sizePrevSlice) {
|
||||
engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
|
||||
(beginIndex + sizeNodeSlice) * sizeof(DType),
|
||||
sizePrevSlice * sizeof(DType));
|
||||
}
|
||||
|
||||
// print message to the tracker
|
||||
inline void TrackerPrint(const std::string &msg) {
|
||||
engine::GetEngine()->TrackerPrint(msg);
|
||||
|
||||
@ -8,8 +8,9 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <map>
|
||||
#include <netinet/tcp.h>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
@ -221,6 +222,12 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
timeout_sec = atoi(val);
|
||||
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
|
||||
}
|
||||
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
||||
if (!strcmp(val, "true"))
|
||||
rabit_enable_tcp_no_delay = true;
|
||||
else
|
||||
rabit_enable_tcp_no_delay = false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
@ -420,11 +427,16 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
int tcpNoDelay = 1;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
setsockopt(all_links[i].sock, IPPROTO_TCP,
|
||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||
}
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
|
||||
@ -61,6 +61,10 @@ class AllreduceBase : public IEngine {
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
|
||||
/*! \brief get rank of previous node in ring topology*/
|
||||
virtual int GetRingPrevRank(void) const {
|
||||
return ring_prev->rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
return rank;
|
||||
@ -78,6 +82,35 @@ class AllreduceBase : public IEngine {
|
||||
virtual std::string GetHost(void) const {
|
||||
return host_uri;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
|
||||
slice_begin, slice_end, size_prev_slice) == kSuccess,
|
||||
"AllgatherRing failed");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -550,6 +583,8 @@ class AllreduceBase : public IEngine {
|
||||
int timeout_sec = 1800;
|
||||
// flag to enable rabit_timeout
|
||||
bool rabit_timeout = false;
|
||||
// enable the TCP_NODELAY socket option on worker links (off by default)
|
||||
bool rabit_enable_tcp_no_delay = false;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -25,6 +25,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
force_local = 0;
|
||||
report_stats = 0;
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
}
|
||||
// destructor
|
||||
virtual ~AllreduceMock(void) {}
|
||||
@ -59,6 +60,20 @@ class AllreduceMock : public AllreduceRobust {
|
||||
_file, _line, _caller);
|
||||
tsum_allreduce += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||
slice_begin, slice_end, size_prev_slice);
|
||||
tsum_allgather += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
@ -69,6 +84,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
time_checkpoint = utils::GetTime();
|
||||
if (force_local == 0) {
|
||||
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
||||
@ -98,10 +114,12 @@ class AllreduceMock : public AllreduceRobust {
|
||||
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
|
||||
<< ",check_tcost="<< tcost <<" sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
|
||||
<< ",allgather_tcost=" << tsum_allgather << " sec"
|
||||
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
||||
this->TrackerPrint(ss.str());
|
||||
}
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
@ -116,6 +134,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
int report_stats;
|
||||
// sum of allreduce
|
||||
double tsum_allreduce;
|
||||
// sum of allgather
|
||||
double tsum_allgather;
|
||||
double time_checkpoint;
|
||||
|
||||
private:
|
||||
|
||||
@ -143,13 +143,84 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||
|
||||
size_t siz = 0;
|
||||
void* temp = cachebuf.Query(index, &siz);
|
||||
_assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||
_assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||
_assert(siz > 0, "cache size should be greater than 0");
|
||||
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||
utils::Assert(siz > 0, "cache size should be greater than 0");
|
||||
std::memcpy(buf, temp, type_nbytes*count);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void AllreduceRobust::Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
// genreate unique allgather signature
|
||||
std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
|
||||
+ std::string(_caller) + "#" +std::to_string(total_size);
|
||||
|
||||
// try fetch bootstrap allgather results from cache
|
||||
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
||||
GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return;
|
||||
|
||||
double start = utils::GetTime();
|
||||
bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
|
||||
|
||||
if (resbuf.LastSeqNo() != -1 &&
|
||||
(result_buffer_round == -1 ||
|
||||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||
resbuf.DropLast();
|
||||
}
|
||||
|
||||
void *temp = resbuf.AllocTemp(total_size, 1);
|
||||
while (true) {
|
||||
if (recovered) {
|
||||
std::memcpy(temp, sendrecvbuf, total_size); break;
|
||||
} else {
|
||||
std::memcpy(temp, sendrecvbuf, total_size);
|
||||
if (CheckAndRecover(TryAllgatherRing(temp, total_size,
|
||||
slice_begin, slice_end, size_prev_slice))) {
|
||||
std::memcpy(sendrecvbuf, temp, total_size); break;
|
||||
} else {
|
||||
recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
|
||||
}
|
||||
}
|
||||
}
|
||||
double delta = utils::GetTime() - start;
|
||||
// log allgather latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo("[%d] allgather (%s) finished version %d, seq %d, take %f seconds\n",
|
||||
rank, key.c_str(), version_number, seq_counter, delta);
|
||||
}
|
||||
|
||||
// if bootstrap allgather, store and fetch through cache
|
||||
if (checkpoint_loaded || !rabit_bootstrap_cache) {
|
||||
resbuf.PushTemp(seq_counter, total_size, 1);
|
||||
seq_counter += 1;
|
||||
} else {
|
||||
SetBootstrapCache(key, sendrecvbuf, total_size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
|
||||
@ -51,6 +51,29 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
||||
const size_t count);
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
|
||||
86
src/c_api.cc
86
src/c_api.cc
@ -13,7 +13,7 @@ namespace c_api {
|
||||
// helper use to avoid BitOR operator
|
||||
template<typename OP, typename DType>
|
||||
struct FHelper {
|
||||
inline static void
|
||||
static void
|
||||
Allreduce(DType *senrecvbuf_,
|
||||
size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
@ -25,7 +25,7 @@ struct FHelper {
|
||||
|
||||
template<typename DType>
|
||||
struct FHelper<op::BitOR, DType> {
|
||||
inline static void
|
||||
static void
|
||||
Allreduce(DType *senrecvbuf_,
|
||||
size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
@ -35,11 +35,11 @@ struct FHelper<op::BitOR, DType> {
|
||||
};
|
||||
|
||||
template<typename OP>
|
||||
inline void Allreduce_(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
void Allreduce_(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi;
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
@ -85,12 +85,12 @@ inline void Allreduce_(void *sendrecvbuf_,
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
}
|
||||
inline void Allreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
engine::mpi::OpType enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
void Allreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
engine::mpi::OpType enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi;
|
||||
switch (enum_op) {
|
||||
case kMax:
|
||||
@ -120,8 +120,45 @@ inline void Allreduce(void *sendrecvbuf,
|
||||
default: utils::Error("unknown enum_op");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Allgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype) {
|
||||
using namespace engine::mpi;
|
||||
size_t type_size = 0;
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
type_size = sizeof(char);
|
||||
break;
|
||||
case kUChar:
|
||||
type_size = sizeof(unsigned char);
|
||||
break;
|
||||
case kInt:
|
||||
type_size = sizeof(int);
|
||||
break;
|
||||
case kUInt:
|
||||
type_size = sizeof(unsigned);
|
||||
break;
|
||||
case kLong:
|
||||
type_size = sizeof(int64_t);
|
||||
break;
|
||||
case kULong:
|
||||
type_size = sizeof(uint64_t);
|
||||
break;
|
||||
case kFloat:
|
||||
type_size = sizeof(float);
|
||||
break;
|
||||
case kDouble:
|
||||
type_size = sizeof(double);
|
||||
break;
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
engine::Allgather
|
||||
(sendrecvbuf_, total_size * type_size, beginIndex * type_size,
|
||||
(beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
|
||||
}
|
||||
|
||||
// wrapper for serialization
|
||||
struct ReadWrapper : public Serializable {
|
||||
@ -170,6 +207,10 @@ bool RabitFinalize() {
|
||||
return rabit::Finalize();
|
||||
}
|
||||
|
||||
int RabitGetRingPrevRank() {
|
||||
return rabit::GetRingPrevRank();
|
||||
}
|
||||
|
||||
int RabitGetRank() {
|
||||
return rabit::GetRank();
|
||||
}
|
||||
@ -203,6 +244,21 @@ void RabitBroadcast(void *sendrecv_data,
|
||||
rabit::Broadcast(sendrecv_data, size, root);
|
||||
}
|
||||
|
||||
void RabitAllgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype) {
|
||||
rabit::c_api::Allgather(sendrecvbuf_,
|
||||
total_size,
|
||||
beginIndex,
|
||||
size_node_slice,
|
||||
size_prev_slice,
|
||||
enum_dtype);
|
||||
}
|
||||
|
||||
|
||||
void RabitAllreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
int enum_dtype,
|
||||
|
||||
@ -84,6 +84,15 @@ IEngine *GetEngine() {
|
||||
}
|
||||
}
|
||||
|
||||
// perform in-place allgather, on sendrecvbuf
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice);
|
||||
}
|
||||
|
||||
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
|
||||
@ -25,6 +25,19 @@ class EmptyEngine : public IEngine {
|
||||
EmptyEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
/*! \brief Allgather is not available on the empty engine; always reports an error */
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
                       size_t slice_begin, size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file, const int _line,
                       const char* _caller) {
  utils::Error("EmptyEngine:: Allgather is not supported");
}
|
||||
/*!
 * \brief ring topology is not available on the empty engine
 * \return -1 if utils::Error hands control back; the error is always reported first
 */
virtual int GetRingPrevRank(void) const {
  utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
  // utils::Error is presumably fatal, but a non-void function must not
  // flow off its end, so return an invalid rank explicitly
  return -1;
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
|
||||
@ -27,6 +27,16 @@ class MPIEngine : public IEngine {
|
||||
MPIEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
/*! \brief Allgather is not available on the MPI engine; always reports an error */
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
                       size_t slice_begin, size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file, const int _line,
                       const char* _caller) {
  utils::Error("MPIEngine:: Allgather is not supported");
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
@ -39,6 +49,9 @@ class MPIEngine : public IEngine {
|
||||
utils::Error("MPIEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
/*!
 * \brief ring topology is not available on the MPI engine
 * \return -1 if utils::Error hands control back; the error is always reported first
 */
virtual int GetRingPrevRank(void) const {
  utils::Error("MPIEngine:: GetRingPrevRank is not supported");
  // utils::Error is presumably fatal, but a non-void function must not
  // flow off its end, so return an invalid rank explicitly
  return -1;
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file, const int _line,
|
||||
const char* _caller) {
|
||||
|
||||
66
test/cpp/allreduce_base_test.cpp
Normal file
66
test/cpp/allreduce_base_test.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
#define RABIT_CXXTESTDEFS_H
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "../../src/allreduce_base.h"
|
||||
|
||||
// Init with a single "rabit_task_id=1" argument must set task_id to "1".
TEST(allreduce_base, init_task)
{
  rabit::engine::AllreduceBase base;

  // std::string storage is contiguous, writable and null-terminated, so it
  // can back argv directly; no variable-length array copy is needed
  std::string rabit_task_id = "rabit_task_id=1";

  char* argv[] = {&rabit_task_id[0]};
  base.Init(1, argv);
  EXPECT_EQ(base.task_id, "1");
}
|
||||
|
||||
// Init with task id, bootstrap cache and debug arguments must set all three.
TEST(allreduce_base, init_with_cache_on)
{
  rabit::engine::AllreduceBase base;

  // the std::string buffers back argv directly (contiguous, writable,
  // null-terminated); no variable-length array copies are needed
  std::string rabit_task_id = "rabit_task_id=1";
  std::string rabit_bootstrap_cache = "rabit_bootstrap_cache=1";
  std::string rabit_debug = "rabit_debug=1";

  char* argv[] = {&rabit_task_id[0], &rabit_bootstrap_cache[0], &rabit_debug[0]};
  base.Init(3, argv);
  EXPECT_EQ(base.task_id, "1");
  EXPECT_EQ(base.rabit_bootstrap_cache, 1);
  EXPECT_EQ(base.rabit_debug, 1);
}
|
||||
|
||||
// Init with task id and rabit_reduce_ring_mincount arguments must set both.
TEST(allreduce_base, init_with_ring_reduce)
{
  rabit::engine::AllreduceBase base;

  // the std::string buffers back argv directly (contiguous, writable,
  // null-terminated); no variable-length array copies are needed
  std::string rabit_task_id = "rabit_task_id=1";
  std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1";

  char* argv[] = {&rabit_task_id[0], &rabit_reduce_ring_mincount[0]};
  base.Init(2, argv);
  EXPECT_EQ(base.task_id, "1");
  EXPECT_EQ(base.reduce_ring_mincount, 1);
}
|
||||
51
test/cpp/allreduce_mock_test.cpp
Normal file
51
test/cpp/allreduce_mock_test.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#define RABIT_CXXTESTDEFS_H
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "../../src/allreduce_mock.h"
|
||||
|
||||
// With "mock=0,0,0,0" and rank 0, Allreduce is expected to exit with code 255.
TEST(allreduce_mock, mock_allreduce)
{
  rabit::engine::AllreduceMock m;

  // the std::string buffer backs argv directly (contiguous, writable,
  // null-terminated); no variable-length array copy is needed
  std::string mock_str = "mock=0,0,0,0";

  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 0;
  EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
}
|
||||
|
||||
// With "mock=0,1,2,0" and matching rank/version/seq, Broadcast is expected
// to exit with code 255.
TEST(allreduce_mock, mock_broadcast)
{
  rabit::engine::AllreduceMock m;
  // the std::string buffer backs argv directly (contiguous, writable,
  // null-terminated); no variable-length array copy is needed
  std::string mock_str = "mock=0,1,2,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 0;
  m.version_number=1;
  m.seq_counter=2;
  EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
}
|
||||
|
||||
// With "mock=3,13,22,0" and matching rank/version/seq, Allgather is expected
// to exit with code 255.
TEST(allreduce_mock, mock_gather)
{
  rabit::engine::AllreduceMock m;
  // the std::string buffer backs argv directly (contiguous, writable,
  // null-terminated); no variable-length array copy is needed
  std::string mock_str = "mock=3,13,22,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 3;
  m.version_number=13;
  m.seq_counter=22;
  EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
}
|
||||
@ -26,7 +26,7 @@ class Model : public rabit::Serializable {
|
||||
}
|
||||
};
|
||||
|
||||
inline void TestMax(Model *model, int ntrial, int iter) {
|
||||
inline void TestMax(Model *model, int iter) {
|
||||
int rank = rabit::GetRank();
|
||||
int nproc = rabit::GetWorldSize();
|
||||
const int z = iter + 111;
|
||||
@ -47,7 +47,7 @@ inline void TestMax(Model *model, int ntrial, int iter) {
|
||||
model->data = ndata;
|
||||
}
|
||||
|
||||
inline void TestSum(Model *model, int ntrial, int iter) {
|
||||
inline void TestSum(Model *model, int iter) {
|
||||
int rank = rabit::GetRank();
|
||||
int nproc = rabit::GetWorldSize();
|
||||
const int z = 131 + iter;
|
||||
@ -69,7 +69,30 @@ inline void TestSum(Model *model, int ntrial, int iter) {
|
||||
model->data = ndata;
|
||||
}
|
||||
|
||||
inline void TestBcast(size_t n, int root, int ntrial, int iter) {
|
||||
// Allgather round trip: each rank fills its own slice of a buffer that is
// nproc times the model size, gathers all slices, checks every slice against
// the values its owning rank is expected to have written, and finally stores
// the gathered buffer as the new model data.
inline void TestAllgather(Model *model, int iter) {
  const int rank = rabit::GetRank();
  const int nproc = rabit::GetWorldSize();
  const int modulus = 131 + iter;
  const size_t slice_len = model->data.size();

  std::vector<float> gathered(slice_len * nproc);
  const size_t slice_begin = rank * slice_len;
  // this rank's contribution
  for (size_t k = 0; k < slice_len; ++k) {
    gathered[slice_begin + k] = (k * (rank + 1)) % modulus + model->data[k];
  }
  Allgather(&gathered[0], gathered.size(), slice_begin, slice_len, slice_len);

  // recompute what each owning rank should have written and compare
  for (size_t pos = 0; pos < gathered.size(); ++pos) {
    const int owner = pos / slice_len;
    const int offset = pos % slice_len;
    const float expected = (offset * (owner + 1)) % modulus + model->data[offset];
    utils::Check(fabsf(expected - gathered[pos]) < 1e-5 ,
                 "[%d] TestAllgather check failure, local=%g, allgatherring=%g",
                 rank, expected, gathered[pos]);
  }
  model->data = gathered;
}
|
||||
|
||||
inline void TestBcast(size_t n, int root) {
|
||||
int rank = rabit::GetRank();
|
||||
std::string s; s.resize(n);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
@ -113,19 +136,22 @@ int main(int argc, char *argv[]) {
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, ntrial, r);
|
||||
TestMax(&model, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
TestBcast(n, i);
|
||||
}
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
|
||||
TestSum(&model, ntrial, r);
|
||||
TestSum(&model, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
TestAllgather(&model, r);
|
||||
printf("[%d] !!!TestAllgather pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model);
|
||||
printf("[%d] !!!Checkpoint pass, iter=%d\n", rank, r);
|
||||
}
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user