rabit seems ready to run
This commit is contained in:
parent
0d63646015
commit
1c5167d96e
@ -9,6 +9,7 @@
|
||||
#define NOMINMAX
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include "./io.h"
|
||||
#include "./utils.h"
|
||||
#include "./engine_robust.h"
|
||||
|
||||
@ -16,6 +17,7 @@ namespace rabit {
|
||||
namespace engine {
|
||||
AllReduceRobust::AllReduceRobust(void) {
|
||||
result_buffer_round = 1;
|
||||
seq_counter = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
@ -58,9 +60,26 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
|
||||
* \param root the root worker id to broadcast the data
|
||||
*/
|
||||
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||
"AllReduce failed");
|
||||
// TODO
|
||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||
// now we are free to remove the last result, if any
|
||||
if (resbuf.LastSeqNo() != -1 &&
|
||||
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||
resbuf.DropLast();
|
||||
}
|
||||
void *temp = resbuf.AllocTemp(1, total_size);
|
||||
while (true) {
|
||||
if (recovered) {
|
||||
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
||||
} else {
|
||||
if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) {
|
||||
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
||||
} else {
|
||||
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
resbuf.PushTemp(seq_counter, 1, total_size);
|
||||
seq_counter += 1;
|
||||
}
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
@ -69,15 +88,43 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
* false if there was no stored checkpoint, means we are start over gain
|
||||
*/
|
||||
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
// TODO
|
||||
return false;
|
||||
// check if we succesfll
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
||||
// if loaded model is empty, this simply means we did not call checkpoint yet
|
||||
// ask caller to reinit model
|
||||
if (checked_model.length() == 0) return false;
|
||||
// load from buffer
|
||||
utils::MemoryBufferStream fs(&checked_model);
|
||||
p_model->Load(fs);
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// run another phase of check ack, if recovered from data
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||
"check ack must return true");
|
||||
return true;
|
||||
} else {
|
||||
// nothing loaded, a fresh start, everyone init model
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* \param p_model pointer to the model
|
||||
*/
|
||||
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
|
||||
// TODO
|
||||
// save model
|
||||
checked_model.resize(0);
|
||||
utils::MemoryBufferStream fs(&checked_model);
|
||||
model.Save(fs);
|
||||
utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something");
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
||||
"check point must return true");
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||
"check ack must return true");
|
||||
}
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
|
||||
@ -154,7 +154,7 @@ class AllReduceRobust : public AllReduceBase {
|
||||
rptr_.clear(); rptr_.push_back(0);
|
||||
data_.clear();
|
||||
}
|
||||
// allocate temporal space for
|
||||
// allocate temporal space
|
||||
inline void *AllocTemp(size_t type_nbytes, size_t count) {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user