change AllReduce to Allreduce
This commit is contained in:
parent
8cb5b68cb6
commit
ed1de6df80
@ -13,7 +13,7 @@
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
// constructor
|
// constructor
|
||||||
AllReduceBase::AllReduceBase(void) {
|
AllreduceBase::AllreduceBase(void) {
|
||||||
master_uri = "NULL";
|
master_uri = "NULL";
|
||||||
master_port = 9000;
|
master_port = 9000;
|
||||||
host_uri = "";
|
host_uri = "";
|
||||||
@ -26,7 +26,7 @@ AllReduceBase::AllReduceBase(void) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
void AllReduceBase::Init(void) {
|
void AllreduceBase::Init(void) {
|
||||||
utils::Socket::Startup();
|
utils::Socket::Startup();
|
||||||
// single node mode
|
// single node mode
|
||||||
if (master_uri == "NULL") return;
|
if (master_uri == "NULL") return;
|
||||||
@ -68,7 +68,7 @@ void AllReduceBase::Init(void) {
|
|||||||
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
|
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
|
||||||
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
|
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
|
||||||
links[0].sock.Create();
|
links[0].sock.Create();
|
||||||
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
|
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
|
||||||
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
|
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
|
||||||
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
|
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
|
||||||
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
|
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
|
||||||
@ -105,7 +105,7 @@ void AllReduceBase::Init(void) {
|
|||||||
// done
|
// done
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllReduceBase::Shutdown(void) {
|
void AllreduceBase::Shutdown(void) {
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
links[i].sock.Close();
|
links[i].sock.Close();
|
||||||
}
|
}
|
||||||
@ -117,7 +117,7 @@ void AllReduceBase::Shutdown(void) {
|
|||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
void AllReduceBase::SetParam(const char *name, const char *val) {
|
void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||||
if (!strcmp(name, "reduce_buffer")) {
|
if (!strcmp(name, "reduce_buffer")) {
|
||||||
@ -140,10 +140,10 @@ void AllReduceBase::SetParam(const char *name, const char *val) {
|
|||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
*
|
*
|
||||||
* NOTE on AllReduce:
|
* NOTE on Allreduce:
|
||||||
* A kSuccess return from TryAllReduce does NOT mean every node has successfully finished TryAllReduce.
|
* A kSuccess return from TryAllreduce does NOT mean every node has successfully finished TryAllreduce.
|
||||||
* It only means the current node got the correct result of AllReduce.
|
* It only means the current node got the correct result of Allreduce.
|
||||||
* However, it means every node finishes LAST call(instead of this one) of AllReduce/Bcast
|
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||||
*
|
*
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
@ -152,8 +152,8 @@ void AllReduceBase::SetParam(const char *name, const char *val) {
|
|||||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllReduceBase::ReturnType
|
AllreduceBase::ReturnType
|
||||||
AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
@ -248,7 +248,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
size_t start = size_up_reduce % buffer_size;
|
size_t start = size_up_reduce % buffer_size;
|
||||||
// perform read till end of buffer
|
// perform read till end of buffer
|
||||||
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
|
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
|
||||||
utils::Assert(nread % type_nbytes == 0, "AllReduce: size check");
|
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != parent_index) {
|
if (i != parent_index) {
|
||||||
reducer(links[i].buffer_head + start,
|
reducer(links[i].buffer_head + start,
|
||||||
@ -280,7 +280,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
}
|
}
|
||||||
if (len != -1) {
|
if (len != -1) {
|
||||||
size_down_in += static_cast<size_t>(len);
|
size_down_in += static_cast<size_t>(len);
|
||||||
utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
|
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
|
||||||
} else {
|
} else {
|
||||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||||
}
|
}
|
||||||
@ -306,8 +306,8 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllReduceBase::ReturnType
|
AllreduceBase::ReturnType
|
||||||
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||||
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||||
// number of links
|
// number of links
|
||||||
|
|||||||
@ -27,14 +27,14 @@ class Datatype {
|
|||||||
}
|
}
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
/*! \brief implementation of basic AllReduce engine */
|
/*! \brief implementation of basic Allreduce engine */
|
||||||
class AllReduceBase : public IEngine {
|
class AllreduceBase : public IEngine {
|
||||||
public:
|
public:
|
||||||
// magic number to verify server
|
// magic number to verify server
|
||||||
const static int kMagic = 0xff99;
|
const static int kMagic = 0xff99;
|
||||||
// constant one byte out of band message to indicate error happening
|
// constant one byte out of band message to indicate error happening
|
||||||
AllReduceBase(void);
|
AllreduceBase(void);
|
||||||
virtual ~AllReduceBase(void) {}
|
virtual ~AllreduceBase(void) {}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
void Init(void);
|
void Init(void);
|
||||||
// shutdown the engine
|
// shutdown the engine
|
||||||
@ -65,12 +65,12 @@ class AllReduceBase : public IEngine {
|
|||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param reducer reduce function
|
* \param reducer reduce function
|
||||||
*/
|
*/
|
||||||
virtual void AllReduce(void *sendrecvbuf_,
|
virtual void Allreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
utils::Assert(TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
|
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
|
||||||
"AllReduce failed");
|
"Allreduce failed");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief broadcast data from root to all nodes
|
* \brief broadcast data from root to all nodes
|
||||||
@ -80,7 +80,7 @@ class AllReduceBase : public IEngine {
|
|||||||
*/
|
*/
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||||
"AllReduce failed");
|
"Allreduce failed");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
@ -171,7 +171,7 @@ class AllReduceBase : public IEngine {
|
|||||||
*/
|
*/
|
||||||
inline bool ReadToRingBuffer(size_t protect_start) {
|
inline bool ReadToRingBuffer(size_t protect_start) {
|
||||||
size_t ngap = size_read - protect_start;
|
size_t ngap = size_read - protect_start;
|
||||||
utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
|
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||||
size_t offset = size_read % buffer_size;
|
size_t offset = size_read % buffer_size;
|
||||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||||
if (nmax == 0) return true;
|
if (nmax == 0) return true;
|
||||||
@ -225,10 +225,10 @@ class AllReduceBase : public IEngine {
|
|||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
*
|
*
|
||||||
* NOTE on AllReduce:
|
* NOTE on Allreduce:
|
||||||
* A kSuccess return from TryAllReduce does NOT mean every node has successfully finished TryAllReduce.
|
* A kSuccess return from TryAllreduce does NOT mean every node has successfully finished TryAllreduce.
|
||||||
* It only means the current node got the correct result of AllReduce.
|
* It only means the current node got the correct result of Allreduce.
|
||||||
* However, it means every node finishes LAST call(instead of this one) of AllReduce/Bcast
|
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||||
*
|
*
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
@ -237,7 +237,7 @@ class AllReduceBase : public IEngine {
|
|||||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
ReturnType TryAllReduce(void *sendrecvbuf_,
|
ReturnType TryAllreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer);
|
ReduceFunction reducer);
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file allreduce_robust-inl.h
|
* \file allreduce_robust-inl.h
|
||||||
* \brief implementation of inline template function in AllReduceRobust
|
* \brief implementation of inline template function in AllreduceRobust
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
@ -29,8 +29,8 @@ namespace engine {
|
|||||||
* \tparam NodeType type of node value
|
* \tparam NodeType type of node value
|
||||||
*/
|
*/
|
||||||
template<typename NodeType, typename EdgeType>
|
template<typename NodeType, typename EdgeType>
|
||||||
inline AllReduceRobust::ReturnType
|
inline AllreduceRobust::ReturnType
|
||||||
AllReduceRobust::MsgPassing(const NodeType &node_value,
|
AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||||
std::vector<EdgeType> *p_edge_in,
|
std::vector<EdgeType> *p_edge_in,
|
||||||
std::vector<EdgeType> *p_edge_out,
|
std::vector<EdgeType> *p_edge_out,
|
||||||
EdgeType (*func) (const NodeType &node_value,
|
EdgeType (*func) (const NodeType &node_value,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file allreduce_robust.cc
|
* \file allreduce_robust.cc
|
||||||
* \brief Robust implementation of AllReduce
|
* \brief Robust implementation of Allreduce
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
@ -15,12 +15,12 @@
|
|||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
AllReduceRobust::AllReduceRobust(void) {
|
AllreduceRobust::AllreduceRobust(void) {
|
||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
seq_counter = 0;
|
seq_counter = 0;
|
||||||
}
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
void AllReduceRobust::Shutdown(void) {
|
void AllreduceRobust::Shutdown(void) {
|
||||||
// need to sync the exec before we shutdown, do a pseudo check point
|
// need to sync the exec before we shutdown, do a pseudo check point
|
||||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
||||||
@ -30,15 +30,15 @@ void AllReduceRobust::Shutdown(void) {
|
|||||||
// execute check ack step, load happens here
|
// execute check ack step, load happens here
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||||
"check ack must return true");
|
"check ack must return true");
|
||||||
AllReduceBase::Shutdown();
|
AllreduceBase::Shutdown();
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
void AllReduceRobust::SetParam(const char *name, const char *val) {
|
void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||||
AllReduceBase::SetParam(name, val);
|
AllreduceBase::SetParam(name, val);
|
||||||
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
|
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
|
||||||
if (!strcmp(name, "result_replicate")) {
|
if (!strcmp(name, "result_replicate")) {
|
||||||
result_buffer_round = std::max(world_size / atoi(val), 1);
|
result_buffer_round = std::max(world_size / atoi(val), 1);
|
||||||
@ -52,7 +52,7 @@ void AllReduceRobust::SetParam(const char *name, const char *val) {
|
|||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param reducer reduce function
|
* \param reducer reduce function
|
||||||
*/
|
*/
|
||||||
void AllReduceRobust::AllReduce(void *sendrecvbuf_,
|
void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
@ -68,7 +68,7 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
|
|||||||
std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break;
|
std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break;
|
||||||
} else {
|
} else {
|
||||||
std::memcpy(temp, sendrecvbuf_, type_nbytes * count);
|
std::memcpy(temp, sendrecvbuf_, type_nbytes * count);
|
||||||
if (CheckAndRecover(TryAllReduce(temp, type_nbytes, count, reducer))) {
|
if (CheckAndRecover(TryAllreduce(temp, type_nbytes, count, reducer))) {
|
||||||
std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break;
|
std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break;
|
||||||
} else {
|
} else {
|
||||||
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
||||||
@ -84,7 +84,7 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
|
|||||||
* \param size the size of the data to be broadcasted
|
* \param size the size of the data to be broadcasted
|
||||||
* \param root the root worker id to broadcast the data
|
* \param root the root worker id to broadcast the data
|
||||||
*/
|
*/
|
||||||
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||||
// now we are free to remove the last result, if any
|
// now we are free to remove the last result, if any
|
||||||
if (resbuf.LastSeqNo() != -1 &&
|
if (resbuf.LastSeqNo() != -1 &&
|
||||||
@ -114,7 +114,7 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
int AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
// check if we succeeded
|
// check if we succeeded
|
||||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
@ -142,7 +142,7 @@ int AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
|||||||
* \param p_model pointer to the model
|
* \param p_model pointer to the model
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
|
void AllreduceRobust::CheckPoint(const utils::ISerializable &model) {
|
||||||
// increase version number
|
// increase version number
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
// save model
|
// save model
|
||||||
@ -168,7 +168,7 @@ void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
|
|||||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||||
* and some link recovery procedure is needed
|
* and some link recovery procedure is needed
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) {
|
AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
@ -285,7 +285,7 @@ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) {
|
|||||||
* \brief try to reconnect the broken links
|
* \brief try to reconnect the broken links
|
||||||
* \return this function can return kSuccess or kSockError
|
* \return this function can return kSuccess or kSockError
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType AllReduceRobust::TryReConnectLinks(void) {
|
AllreduceRobust::ReturnType AllreduceRobust::TryReConnectLinks(void) {
|
||||||
utils::Error("TryReConnectLinks: not implemented");
|
utils::Error("TryReConnectLinks: not implemented");
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
@ -296,7 +296,7 @@ AllReduceRobust::ReturnType AllReduceRobust::TryReConnectLinks(void) {
|
|||||||
* \param err_type the type of error happening in the system
|
* \param err_type the type of error happening in the system
|
||||||
* \return true if err_type is kSuccess, false otherwise
|
* \return true if err_type is kSuccess, false otherwise
|
||||||
*/
|
*/
|
||||||
bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
|
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||||
if (err_type == kSuccess) return true;
|
if (err_type == kSuccess) return true;
|
||||||
while(err_type != kSuccess) {
|
while(err_type != kSuccess) {
|
||||||
switch(err_type) {
|
switch(err_type) {
|
||||||
@ -383,8 +383,8 @@ inline char DataRequest(const std::pair<bool, int> &node_value,
|
|||||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType
|
AllreduceRobust::ReturnType
|
||||||
AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
|
AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
||||||
size_t *p_size,
|
size_t *p_size,
|
||||||
int *p_recvlink,
|
int *p_recvlink,
|
||||||
std::vector<bool> *p_req_in) {
|
std::vector<bool> *p_req_in) {
|
||||||
@ -398,7 +398,7 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
|
|||||||
for (size_t i = 0; i < dist_in.size(); ++i) {
|
for (size_t i = 0; i < dist_in.size(); ++i) {
|
||||||
if (dist_in[i].first != std::numeric_limits<int>::max()) {
|
if (dist_in[i].first != std::numeric_limits<int>::max()) {
|
||||||
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
|
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
|
||||||
"[%d] AllReduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n",
|
"[%d] Allreduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n",
|
||||||
rank, dist_in[i].first, *p_size, dist_in[i].second);
|
rank, dist_in[i].first, *p_size, dist_in[i].second);
|
||||||
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
|
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
|
||||||
best_link = static_cast<int>(i);
|
best_link = static_cast<int>(i);
|
||||||
@ -444,8 +444,8 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
|
|||||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType, TryDecideRouting
|
* \sa ReturnType, TryDecideRouting
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType
|
AllreduceRobust::ReturnType
|
||||||
AllReduceRobust::TryRecoverData(RecoverType role,
|
AllreduceRobust::TryRecoverData(RecoverType role,
|
||||||
void *sendrecvbuf_,
|
void *sendrecvbuf_,
|
||||||
size_t size,
|
size_t size,
|
||||||
int recv_link,
|
int recv_link,
|
||||||
@ -546,7 +546,7 @@ AllReduceRobust::TryRecoverData(RecoverType role,
|
|||||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
|
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||||
RecoverType role = requester ? kRequestData : kHaveData;
|
RecoverType role = requester ? kRequestData : kHaveData;
|
||||||
size_t size = this->checked_model.length();
|
size_t size = this->checked_model.length();
|
||||||
int recv_link;
|
int recv_link;
|
||||||
@ -573,8 +573,8 @@ AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
|
|||||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType
|
AllreduceRobust::ReturnType
|
||||||
AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
|
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
|
||||||
if (!requester) {
|
if (!requester) {
|
||||||
sendrecvbuf = resbuf.Query(seqno, &size);
|
sendrecvbuf = resbuf.Query(seqno, &size);
|
||||||
role = sendrecvbuf != NULL ? kHaveData : kPassData;
|
role = sendrecvbuf != NULL ? kHaveData : kPassData;
|
||||||
@ -605,7 +605,7 @@ AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
|||||||
* result by recovering procedure, the action is complete, no further action is needed
|
* result by recovering procedure, the action is complete, no further action is needed
|
||||||
* - false means this is the latest action that has not yet been executed, need to execute the action
|
* - false means this is the latest action that has not yet been executed, need to execute the action
|
||||||
*/
|
*/
|
||||||
bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||||
if (flag != 0) {
|
if (flag != 0) {
|
||||||
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
|
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
|
||||||
}
|
}
|
||||||
@ -615,7 +615,7 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
|||||||
// action
|
// action
|
||||||
ActionSummary act = req;
|
ActionSummary act = req;
|
||||||
// get the reduced action
|
// get the reduced action
|
||||||
if (!CheckAndRecover(TryAllReduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
|
if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
|
||||||
if (act.check_ack()) {
|
if (act.check_ack()) {
|
||||||
if (act.check_point()) {
|
if (act.check_point()) {
|
||||||
// if we also have check_point, do check point first
|
// if we also have check_point, do check point first
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file allreduce_robust.h
|
* \file allreduce_robust.h
|
||||||
* \brief Robust implementation of AllReduce
|
* \brief Robust implementation of Allreduce
|
||||||
* using TCP non-block socket and tree-shape reduction.
|
* using TCP non-block socket and tree-shape reduction.
|
||||||
*
|
*
|
||||||
* This implementation considers the failure of nodes
|
* This implementation considers the failure of nodes
|
||||||
@ -16,10 +16,10 @@
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
/*! \brief implementation of fault tolerant all reduce engine */
|
/*! \brief implementation of fault tolerant all reduce engine */
|
||||||
class AllReduceRobust : public AllReduceBase {
|
class AllreduceRobust : public AllreduceBase {
|
||||||
public:
|
public:
|
||||||
AllReduceRobust(void);
|
AllreduceRobust(void);
|
||||||
virtual ~AllReduceRobust(void) {}
|
virtual ~AllreduceRobust(void) {}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
@ -36,7 +36,7 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param reducer reduce function
|
* \param reducer reduce function
|
||||||
*/
|
*/
|
||||||
virtual void AllReduce(void *sendrecvbuf_,
|
virtual void Allreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer);
|
ReduceFunction reducer);
|
||||||
@ -142,7 +142,7 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
inline int flag(void) const {
|
inline int flag(void) const {
|
||||||
return seqcode & 15;
|
return seqcode & 15;
|
||||||
}
|
}
|
||||||
// reducer for AllReduce, used to get the result ActionSummary from all nodes
|
// reducer for Allreduce, used to get the result ActionSummary from all nodes
|
||||||
inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||||
const ActionSummary *src = (const ActionSummary*)src_;
|
const ActionSummary *src = (const ActionSummary*)src_;
|
||||||
ActionSummary *dst = (ActionSummary*)dst_;
|
ActionSummary *dst = (ActionSummary*)dst_;
|
||||||
@ -162,7 +162,7 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
// internal sequence code
|
// internal sequence code
|
||||||
int seqcode;
|
int seqcode;
|
||||||
};
|
};
|
||||||
/*! \brief data structure to remember result of Bcast and AllReduce calls */
|
/*! \brief data structure to remember result of Bcast and Allreduce calls */
|
||||||
class ResultBuffer {
|
class ResultBuffer {
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
AllReduceRobust manager;
|
AllreduceRobust manager;
|
||||||
|
|
||||||
/*! \brief initialize the synchronization module */
|
/*! \brief initialize the synchronization module */
|
||||||
void Init(int argc, char *argv[]) {
|
void Init(int argc, char *argv[]) {
|
||||||
@ -38,13 +38,13 @@ IEngine *GetEngine(void) {
|
|||||||
return &manager;
|
return &manager;
|
||||||
}
|
}
|
||||||
// perform in-place allreduce, on sendrecvbuf
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
void AllReduce_(void *sendrecvbuf,
|
void Allreduce_(void *sendrecvbuf,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
IEngine::ReduceFunction red,
|
IEngine::ReduceFunction red,
|
||||||
mpi::DataType dtype,
|
mpi::DataType dtype,
|
||||||
mpi::OpType op) {
|
mpi::OpType op) {
|
||||||
GetEngine()->AllReduce(sendrecvbuf, type_nbytes, count, red);
|
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red);
|
||||||
}
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -16,7 +16,7 @@ class Datatype;
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
/*! \brief core interface of engine */
|
/*! \brief core interface of engine */
|
||||||
namespace engine {
|
namespace engine {
|
||||||
/*! \brief interface of core AllReduce engine */
|
/*! \brief interface of core Allreduce engine */
|
||||||
class IEngine {
|
class IEngine {
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
@ -41,7 +41,7 @@ class IEngine {
|
|||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param reducer reduce function
|
* \param reducer reduce function
|
||||||
*/
|
*/
|
||||||
virtual void AllReduce(void *sendrecvbuf_,
|
virtual void Allreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) = 0;
|
ReduceFunction reducer) = 0;
|
||||||
@ -130,7 +130,7 @@ enum DataType {
|
|||||||
* \param dtype the data type
|
* \param dtype the data type
|
||||||
* \param op the reduce operator type
|
* \param op the reduce operator type
|
||||||
*/
|
*/
|
||||||
void AllReduce_(void *sendrecvbuf,
|
void Allreduce_(void *sendrecvbuf,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
IEngine::ReduceFunction red,
|
IEngine::ReduceFunction red,
|
||||||
|
|||||||
@ -20,11 +20,11 @@ class MPIEngine : public IEngine {
|
|||||||
MPIEngine(void) {
|
MPIEngine(void) {
|
||||||
version_number = 0;
|
version_number = 0;
|
||||||
}
|
}
|
||||||
virtual void AllReduce(void *sendrecvbuf_,
|
virtual void Allreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
utils::Error("MPIEngine:: AllReduce is not supported, use AllReduce_ instead");
|
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||||
}
|
}
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||||
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
|
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
|
||||||
@ -103,7 +103,7 @@ inline MPI::Op GetOp(mpi::OpType otype) {
|
|||||||
return MPI::MAX;
|
return MPI::MAX;
|
||||||
}
|
}
|
||||||
// perform in-place allreduce, on sendrecvbuf
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
void AllReduce_(void *sendrecvbuf,
|
void Allreduce_(void *sendrecvbuf,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
IEngine::ReduceFunction red,
|
IEngine::ReduceFunction red,
|
||||||
|
|||||||
@ -25,9 +25,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename OP>
|
template<typename OP>
|
||||||
inline void AllReduce(float *sendrecvbuf, size_t count) {
|
inline void Allreduce(float *sendrecvbuf, size_t count) {
|
||||||
utils::Assert(verify(allReduce), "[%d] error when calling allReduce", rank);
|
utils::Assert(verify(allReduce), "[%d] error when calling allReduce", rank);
|
||||||
rabit::AllReduce<OP>(sendrecvbuf, count);
|
rabit::Allreduce<OP>(sendrecvbuf, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
|
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
|
|||||||
@ -101,10 +101,10 @@ inline void Bcast(std::string *sendrecv_data, int root) {
|
|||||||
e->Broadcast(&(*sendrecv_data)[0], len, root);
|
e->Broadcast(&(*sendrecv_data)[0], len, root);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// perform inplace AllReduce
|
// perform inplace Allreduce
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
inline void Allreduce(DType *sendrecvbuf, size_t count) {
|
||||||
engine::AllReduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
||||||
engine::mpi::GetType<DType>(), OP::kType);
|
engine::mpi::GetType<DType>(), OP::kType);
|
||||||
}
|
}
|
||||||
// load latest check point
|
// load latest check point
|
||||||
|
|||||||
@ -2,9 +2,9 @@
|
|||||||
#define RABIT_RABIT_H
|
#define RABIT_RABIT_H
|
||||||
/*!
|
/*!
|
||||||
* \file rabit.h
|
* \file rabit.h
|
||||||
* \brief This file defines unified AllReduce/Broadcast interface of rabit
|
* \brief This file defines unified Allreduce/Broadcast interface of rabit
|
||||||
* The actual implementation is redirected to rabit engine
|
* The actual implementation is redirected to rabit engine
|
||||||
* Code only using this header can also compiled with MPI AllReduce(with no fault recovery),
|
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
@ -54,7 +54,7 @@ inline void Bcast(std::string *sendrecv_data, int root);
|
|||||||
* Example Usage: the following code gives sum of the result
|
* Example Usage: the following code gives sum of the result
|
||||||
* vector<int> data(10);
|
* vector<int> data(10);
|
||||||
* ...
|
* ...
|
||||||
* AllReduce<op::Sum>(&data[0], data.size());
|
* Allreduce<op::Sum>(&data[0], data.size());
|
||||||
* ...
|
* ...
|
||||||
* \param sendrecvbuf buffer for both sending and recving data
|
* \param sendrecvbuf buffer for both sending and recving data
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
@ -62,7 +62,7 @@ inline void Bcast(std::string *sendrecv_data, int root);
|
|||||||
* \tparam DType type of data
|
* \tparam DType type of data
|
||||||
*/
|
*/
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void AllReduce(DType *sendrecvbuf, size_t count);
|
inline void Allreduce(DType *sendrecvbuf, size_t count);
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param p_model pointer to the model
|
* \param p_model pointer to the model
|
||||||
|
|||||||
@ -68,7 +68,7 @@ class Master:
|
|||||||
try:
|
try:
|
||||||
magic = slave.recvint()
|
magic = slave.recvint()
|
||||||
if magic != kMagic:
|
if magic != kMagic:
|
||||||
print 'invalid magic number=%d from %s' % (magic, s_addr[0])
|
print 'invalid magic number=%d from %s' % (magic, s_addr[0])
|
||||||
slave.sock.close()
|
slave.sock.close()
|
||||||
continue
|
continue
|
||||||
except socket.error:
|
except socket.error:
|
||||||
|
|||||||
@ -15,7 +15,7 @@ inline void TestMax(test::Mock &mock, size_t n) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % 111;
|
ndata[i] = (i * (rank+1)) % 111;
|
||||||
}
|
}
|
||||||
mock.AllReduce<op::Max>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rmax = (i * 1) % 111;
|
float rmax = (i * 1) % 111;
|
||||||
for (int r = 0; r < nproc; ++r) {
|
for (int r = 0; r < nproc; ++r) {
|
||||||
@ -34,7 +34,7 @@ inline void TestSum(test::Mock &mock, size_t n) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z;
|
ndata[i] = (i * (rank+1)) % z;
|
||||||
}
|
}
|
||||||
mock.AllReduce<op::Sum>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rsum = 0.0f;
|
float rsum = 0.0f;
|
||||||
for (int r = 0; r < nproc; ++r) {
|
for (int r = 0; r < nproc; ++r) {
|
||||||
|
|||||||
@ -39,7 +39,7 @@ inline void TestMax(test::Mock &mock, Model *model, int ntrial, int iter) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||||
}
|
}
|
||||||
mock.AllReduce<op::Max>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||||
if (ntrial == iter && rank == 3) {
|
if (ntrial == iter && rank == 3) {
|
||||||
throw MockException();
|
throw MockException();
|
||||||
}
|
}
|
||||||
@ -62,7 +62,7 @@ inline void TestSum(test::Mock &mock, Model *model, int ntrial, int iter) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||||
}
|
}
|
||||||
mock.AllReduce<op::Sum>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
|
|
||||||
if (ntrial == iter && rank == 0) {
|
if (ntrial == iter && rank == 0) {
|
||||||
throw MockException();
|
throw MockException();
|
||||||
|
|||||||
@ -18,7 +18,7 @@ inline void TestMax(test::Mock &mock, size_t n, int ntrial) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % 111;
|
ndata[i] = (i * (rank+1)) % 111;
|
||||||
}
|
}
|
||||||
mock.AllReduce<op::Max>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||||
if (ntrial == 0 && rank == 15) throw MockException();
|
if (ntrial == 0 && rank == 15) throw MockException();
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rmax = (i * 1) % 111;
|
float rmax = (i * 1) % 111;
|
||||||
@ -38,7 +38,7 @@ inline void TestSum(test::Mock &mock, size_t n, int ntrial) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z;
|
ndata[i] = (i * (rank+1)) % z;
|
||||||
}
|
}
|
||||||
mock.AllReduce<op::Sum>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
|
|
||||||
if (ntrial == 0 && rank == 0) throw MockException();
|
if (ntrial == 0 && rank == 0) throw MockException();
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user