rabit seems ready to run
This commit is contained in:
parent
0d63646015
commit
1c5167d96e
@ -9,6 +9,7 @@
|
|||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
#include "./engine_robust.h"
|
#include "./engine_robust.h"
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ namespace rabit {
|
|||||||
namespace engine {
|
namespace engine {
|
||||||
AllReduceRobust::AllReduceRobust(void) {
|
AllReduceRobust::AllReduceRobust(void) {
|
||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
|
seq_counter = 0;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
@ -58,9 +60,26 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
|
|||||||
* \param root the root worker id to broadcast the data
|
* \param root the root worker id to broadcast the data
|
||||||
*/
|
*/
|
||||||
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||||
"AllReduce failed");
|
// now we are free to remove the last result, if any
|
||||||
// TODO
|
if (resbuf.LastSeqNo() != -1 &&
|
||||||
|
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||||
|
resbuf.DropLast();
|
||||||
|
}
|
||||||
|
void *temp = resbuf.AllocTemp(1, total_size);
|
||||||
|
while (true) {
|
||||||
|
if (recovered) {
|
||||||
|
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
||||||
|
} else {
|
||||||
|
if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) {
|
||||||
|
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
||||||
|
} else {
|
||||||
|
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resbuf.PushTemp(seq_counter, 1, total_size);
|
||||||
|
seq_counter += 1;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
@ -69,15 +88,43 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
* false if there was no stored checkpoint, means we are start over gain
|
* false if there was no stored checkpoint, means we are start over gain
|
||||||
*/
|
*/
|
||||||
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
// TODO
|
// check if we succesfll
|
||||||
|
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
||||||
|
// if loaded model is empty, this simply means we did not call checkpoint yet
|
||||||
|
// ask caller to reinit model
|
||||||
|
if (checked_model.length() == 0) return false;
|
||||||
|
// load from buffer
|
||||||
|
utils::MemoryBufferStream fs(&checked_model);
|
||||||
|
p_model->Load(fs);
|
||||||
|
// reset result buffer
|
||||||
|
resbuf.Clear(); seq_counter = 0;
|
||||||
|
// run another phase of check ack, if recovered from data
|
||||||
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||||
|
"check ack must return true");
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// nothing loaded, a fresh start, everyone init model
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* \param p_model pointer to the model
|
* \param p_model pointer to the model
|
||||||
*/
|
*/
|
||||||
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
|
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
|
||||||
// TODO
|
// save model
|
||||||
|
checked_model.resize(0);
|
||||||
|
utils::MemoryBufferStream fs(&checked_model);
|
||||||
|
model.Save(fs);
|
||||||
|
utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something");
|
||||||
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
||||||
|
"check point must return true");
|
||||||
|
// reset result buffer
|
||||||
|
resbuf.Clear(); seq_counter = 0;
|
||||||
|
// execute check ack step, load happens here
|
||||||
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||||
|
"check ack must return true");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||||
|
|||||||
@ -154,7 +154,7 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
rptr_.clear(); rptr_.push_back(0);
|
rptr_.clear(); rptr_.push_back(0);
|
||||||
data_.clear();
|
data_.clear();
|
||||||
}
|
}
|
||||||
// allocate temporal space for
|
// allocate temporal space
|
||||||
inline void *AllocTemp(size_t type_nbytes, size_t count) {
|
inline void *AllocTemp(size_t type_nbytes, size_t count) {
|
||||||
size_t size = type_nbytes * count;
|
size_t size = type_nbytes * count;
|
||||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||||
|
|||||||
1
src/io.h
1
src/io.h
@ -91,6 +91,7 @@ class IStream {
|
|||||||
|
|
||||||
/*! \brief interface of se*/
|
/*! \brief interface of se*/
|
||||||
class ISerializable {
|
class ISerializable {
|
||||||
|
public:
|
||||||
/*! \brief load the model from file */
|
/*! \brief load the model from file */
|
||||||
virtual void Load(IStream &fi) = 0;
|
virtual void Load(IStream &fi) = 0;
|
||||||
/*! \brief save the model to the stream*/
|
/*! \brief save the model to the stream*/
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user