Expose RabitAllGatherRing and RabitGetRingPrevRank (#113)
* Add unit tests
* Expose RabitAllGatherRing and RabitGetRingPrevRank
* Enable TCP_NODELAY to decrease latency
This commit is contained in:
@@ -8,8 +8,9 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <map>
|
||||
#include <netinet/tcp.h>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
@@ -221,6 +222,12 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
timeout_sec = atoi(val);
|
||||
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
|
||||
}
|
||||
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
||||
if (!strcmp(val, "true"))
|
||||
rabit_enable_tcp_no_delay = true;
|
||||
else
|
||||
rabit_enable_tcp_no_delay = false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
@@ -420,11 +427,16 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
int tcpNoDelay = 1;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
setsockopt(all_links[i].sock, IPPROTO_TCP,
|
||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||
}
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
|
||||
@@ -61,6 +61,10 @@ class AllreduceBase : public IEngine {
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
|
||||
/*! \brief get rank of previous node in ring topology*/
|
||||
virtual int GetRingPrevRank(void) const {
|
||||
return ring_prev->rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
return rank;
|
||||
@@ -78,6 +82,35 @@ class AllreduceBase : public IEngine {
|
||||
virtual std::string GetHost(void) const {
|
||||
return host_uri;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
|
||||
slice_begin, slice_end, size_prev_slice) == kSuccess,
|
||||
"AllgatherRing failed");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@@ -550,6 +583,8 @@ class AllreduceBase : public IEngine {
|
||||
int timeout_sec = 1800;
|
||||
// flag to enable rabit_timeout
|
||||
bool rabit_timeout = false;
|
||||
// Enable TCP node delay
|
||||
bool rabit_enable_tcp_no_delay = false;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@@ -25,6 +25,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
force_local = 0;
|
||||
report_stats = 0;
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
}
|
||||
// destructor
|
||||
virtual ~AllreduceMock(void) {}
|
||||
@@ -59,6 +60,20 @@ class AllreduceMock : public AllreduceRobust {
|
||||
_file, _line, _caller);
|
||||
tsum_allreduce += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||
slice_begin, slice_end, size_prev_slice);
|
||||
tsum_allgather += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
@@ -69,6 +84,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
time_checkpoint = utils::GetTime();
|
||||
if (force_local == 0) {
|
||||
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
||||
@@ -98,10 +114,12 @@ class AllreduceMock : public AllreduceRobust {
|
||||
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
|
||||
<< ",check_tcost="<< tcost <<" sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
|
||||
<< ",allgather_tcost=" << tsum_allgather << " sec"
|
||||
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
||||
this->TrackerPrint(ss.str());
|
||||
}
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
@@ -116,6 +134,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
int report_stats;
|
||||
// sum of allreduce
|
||||
double tsum_allreduce;
|
||||
// sum of allgather
|
||||
double tsum_allgather;
|
||||
double time_checkpoint;
|
||||
|
||||
private:
|
||||
|
||||
@@ -143,13 +143,84 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||
|
||||
size_t siz = 0;
|
||||
void* temp = cachebuf.Query(index, &siz);
|
||||
_assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||
_assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||
_assert(siz > 0, "cache size should be greater than 0");
|
||||
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||
utils::Assert(siz > 0, "cache size should be greater than 0");
|
||||
std::memcpy(buf, temp, type_nbytes*count);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*!
 * \brief Allgather function: every node contributes one segment of a conceptual
 *   ring buffer; node k provides [slice_begin, slice_end), and the next node's
 *   segment starts at slice_end. On return, sendrecvbuf contains the full
 *   gathered contents. Uses a ring-based algorithm with fault recovery.
 *
 * \param sendrecvbuf buffer used for both sending and receiving; a ring conceptually
 * \param total_size total number of bytes to be gathered
 * \param slice_begin beginning of this node's slice
 * \param slice_end end of this node's slice
 * \param size_prev_slice size of the slice owned by node (rank - 1) % world_size
 * \param _file caller file name used to generate a unique cache key
 * \param _line caller line number used to generate a unique cache key
 * \param _caller caller function name used to generate a unique cache key
 */
void AllreduceRobust::Allgather(void *sendrecvbuf,
                                size_t total_size,
                                size_t slice_begin,
                                size_t slice_end,
                                size_t size_prev_slice,
                                const char* _file,
                                const int _line,
                                const char* _caller) {
  // single-process (or uninitialized) job: nothing to exchange
  if (world_size == 1 || world_size == -1) return;
  // generate a unique allgather signature from the call site and payload size
  std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
    + std::string(_caller) + "#" +std::to_string(total_size);

  // before a checkpoint is loaded, try to serve this bootstrap allgather from
  // the cache; GetBootstrapCache returns -1 on a miss
  if (!checkpoint_loaded && rabit_bootstrap_cache &&
    GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return;

  double start = utils::GetTime();
  // first try to recover the result from peers (true if a previously committed
  // result for this seq number was restored into sendrecvbuf)
  bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);

  // NOTE(review): presumably this trims the last uncommitted result slot when
  // it falls outside the retention round for this rank — confirm against
  // resbuf/result_buffer_round semantics
  if (resbuf.LastSeqNo() != -1 &&
    (result_buffer_round == -1 ||
     resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
    resbuf.DropLast();
  }

  // stage the operation in a temp slot of the result buffer; retry the ring
  // allgather (with recovery) until it succeeds or the result is recovered
  void *temp = resbuf.AllocTemp(total_size, 1);
  while (true) {
    if (recovered) {
      // result already restored into sendrecvbuf; keep a copy in the result buffer
      std::memcpy(temp, sendrecvbuf, total_size); break;
    } else {
      // work on the temp copy so sendrecvbuf keeps the pristine input for retries
      std::memcpy(temp, sendrecvbuf, total_size);
      if (CheckAndRecover(TryAllgatherRing(temp, total_size,
                                           slice_begin, slice_end, size_prev_slice))) {
        std::memcpy(sendrecvbuf, temp, total_size); break;
      } else {
        // link failure was repaired; attempt to recover the committed result
        recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq);
      }
    }
  }
  double delta = utils::GetTime() - start;
  // log allgather latency
  if (rabit_debug) {
    utils::HandleLogInfo("[%d] allgather (%s) finished version %d, seq %d, take %f seconds\n",
                         rank, key.c_str(), version_number, seq_counter, delta);
  }

  // commit: normal operations go into the result buffer and advance the
  // sequence counter; bootstrap (pre-checkpoint) operations go into the cache
  if (checkpoint_loaded || !rabit_bootstrap_cache) {
    resbuf.PushTemp(seq_counter, total_size, 1);
    seq_counter += 1;
  } else {
    SetBootstrapCache(key, sendrecvbuf, total_size, 1);
  }
}
|
||||
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
|
||||
@@ -51,6 +51,29 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
||||
const size_t count);
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
|
||||
86
src/c_api.cc
86
src/c_api.cc
@@ -13,7 +13,7 @@ namespace c_api {
|
||||
// helper use to avoid BitOR operator
|
||||
template<typename OP, typename DType>
|
||||
struct FHelper {
|
||||
inline static void
|
||||
static void
|
||||
Allreduce(DType *senrecvbuf_,
|
||||
size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
@@ -25,7 +25,7 @@ struct FHelper {
|
||||
|
||||
template<typename DType>
|
||||
struct FHelper<op::BitOR, DType> {
|
||||
inline static void
|
||||
static void
|
||||
Allreduce(DType *senrecvbuf_,
|
||||
size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
@@ -35,11 +35,11 @@ struct FHelper<op::BitOR, DType> {
|
||||
};
|
||||
|
||||
template<typename OP>
|
||||
inline void Allreduce_(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
void Allreduce_(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi;
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
@@ -85,12 +85,12 @@ inline void Allreduce_(void *sendrecvbuf_,
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
}
|
||||
inline void Allreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
engine::mpi::OpType enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
void Allreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
engine::mpi::OpType enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi;
|
||||
switch (enum_op) {
|
||||
case kMax:
|
||||
@@ -120,8 +120,45 @@ inline void Allreduce(void *sendrecvbuf,
|
||||
default: utils::Error("unknown enum_op");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Allgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype) {
|
||||
using namespace engine::mpi;
|
||||
size_t type_size = 0;
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
type_size = sizeof(char);
|
||||
break;
|
||||
case kUChar:
|
||||
type_size = sizeof(unsigned char);
|
||||
break;
|
||||
case kInt:
|
||||
type_size = sizeof(int);
|
||||
break;
|
||||
case kUInt:
|
||||
type_size = sizeof(unsigned);
|
||||
break;
|
||||
case kLong:
|
||||
type_size = sizeof(int64_t);
|
||||
break;
|
||||
case kULong:
|
||||
type_size = sizeof(uint64_t);
|
||||
break;
|
||||
case kFloat:
|
||||
type_size = sizeof(float);
|
||||
break;
|
||||
case kDouble:
|
||||
type_size = sizeof(double);
|
||||
break;
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
engine::Allgather
|
||||
(sendrecvbuf_, total_size * type_size, beginIndex * type_size,
|
||||
(beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
|
||||
}
|
||||
|
||||
// wrapper for serialization
|
||||
struct ReadWrapper : public Serializable {
|
||||
@@ -170,6 +207,10 @@ bool RabitFinalize() {
|
||||
return rabit::Finalize();
|
||||
}
|
||||
|
||||
/*! \brief C API wrapper: rank of the previous node in the ring topology */
int RabitGetRingPrevRank() {
  return rabit::GetRingPrevRank();
}
|
||||
|
||||
/*! \brief C API wrapper: rank of the current process */
int RabitGetRank() {
  return rabit::GetRank();
}
|
||||
@@ -203,6 +244,21 @@ void RabitBroadcast(void *sendrecv_data,
|
||||
rabit::Broadcast(sendrecv_data, size, root);
|
||||
}
|
||||
|
||||
/*!
 * \brief C API wrapper: ring allgather over a typed buffer; slice positions
 *   and sizes are in elements of the given enum_dtype (converted to bytes by
 *   rabit::c_api::Allgather)
 */
void RabitAllgather(void *sendrecvbuf_,
                    size_t total_size,
                    size_t beginIndex,
                    size_t size_node_slice,
                    size_t size_prev_slice,
                    int enum_dtype) {
  rabit::c_api::Allgather(sendrecvbuf_, total_size, beginIndex,
                          size_node_slice, size_prev_slice, enum_dtype);
}
|
||||
|
||||
|
||||
void RabitAllreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
int enum_dtype,
|
||||
|
||||
@@ -84,6 +84,15 @@ IEngine *GetEngine() {
|
||||
}
|
||||
}
|
||||
|
||||
// perform in-place allgather, on sendrecvbuf
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice);
|
||||
}
|
||||
|
||||
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
|
||||
@@ -25,6 +25,19 @@ class EmptyEngine : public IEngine {
|
||||
EmptyEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
// EmptyEngine is the single-process stub; ring allgather is deliberately
// unsupported and always raises a runtime error.
virtual void Allgather(void *sendrecvbuf_,
                       size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file,
                       const int _line,
                       const char* _caller) {
  utils::Error("EmptyEngine:: Allgather is not supported");
}
||||
// GetRingPrevRank has no meaning without a ring; always raises a runtime error.
virtual int GetRingPrevRank(void) const {
  utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
  // utils::Error is presumably fatal, but it is not visibly marked noreturn;
  // return a sentinel so control never flows off the end of a non-void
  // function (undefined behavior, and a -Wreturn-type warning otherwise)
  return -1;
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
|
||||
@@ -27,6 +27,16 @@ class MPIEngine : public IEngine {
|
||||
MPIEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
// The MPI backend has no ring-allgather implementation; this entry point is
// deliberately unsupported and always raises a runtime error.
virtual void Allgather(void *sendrecvbuf_,
                       size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file,
                       const int _line,
                       const char* _caller) {
  utils::Error("MPIEngine:: Allgather is not supported");
}
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
@@ -39,6 +49,9 @@ class MPIEngine : public IEngine {
|
||||
utils::Error("MPIEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
// The MPI backend does not expose a ring topology; always raises a runtime error.
virtual int GetRingPrevRank(void) const {
  utils::Error("MPIEngine:: GetRingPrevRank is not supported");
  // utils::Error is presumably fatal, but it is not visibly marked noreturn;
  // return a sentinel so control never flows off the end of a non-void
  // function (undefined behavior, and a -Wreturn-type warning otherwise)
  return -1;
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file, const int _line,
|
||||
const char* _caller) {
|
||||
|
||||
Reference in New Issue
Block a user