rabit seems ready to run

This commit is contained in:
tqchen 2014-12-01 10:32:30 -08:00
parent 0d63646015
commit 1c5167d96e
3 changed files with 55 additions and 7 deletions

View File

@ -9,6 +9,7 @@
#define NOMINMAX
#include <limits>
#include <utility>
#include "./io.h"
#include "./utils.h"
#include "./engine_robust.h"
@ -16,6 +17,7 @@ namespace rabit {
namespace engine {
AllReduceRobust::AllReduceRobust(void) {
  // sequence numbering of collective calls starts from zero
  this->seq_counter = 0;
  // initial value of the round size used when deciding which results to keep
  this->result_buffer_round = 1;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
@ -58,9 +60,26 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
* \param root the root worker id to broadcast the data
*/
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
// Broadcast total_size bytes held in sendrecvbuf_ from worker `root`, then
// store a copy of the bytes in resbuf under the current seq_counter and
// advance seq_counter by one.
// NOTE(review): as shown here, the Assert + TODO below run in front of the
// recovery-aware path, so TryBroadcast may be invoked twice per call; they look
// like pre-change lines left over from the diff -- confirm which lines are live.
// NOTE(review): the assertion message says "AllReduce" although this is Broadcast.
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"AllReduce failed");
// TODO
// presumably true means the result for seq_counter was already recovered into
// sendrecvbuf_ -- verify against RecoverExec
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
// now we are free to remove the last result, if any
// the last result is dropped only when one exists (LastSeqNo() != -1) and its
// sequence number modulo result_buffer_round differs from rank modulo
// result_buffer_round
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
// request temporary space of total_size elements of 1 byte each
void *temp = resbuf.AllocTemp(1, total_size);
// loop until the data is available, either via recovery or via a broadcast
// that CheckAndRecover accepts; in both cases copy it into temp and leave
while (true) {
if (recovered) {
std::memcpy(temp, sendrecvbuf_, total_size); break;
} else {
if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) {
std::memcpy(temp, sendrecvbuf_, total_size); break;
} else {
// the broadcast attempt was not accepted: run recovery again before retrying
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
}
}
}
// presumably commits the temporary space as the stored result for seq_counter
// -- verify against PushTemp
resbuf.PushTemp(seq_counter, 1, total_size);
seq_counter += 1;
}
/*!
* \brief load latest check point
@ -69,15 +88,43 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
* false if there was no stored checkpoint, meaning we are starting over again
*/
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
// Try to restore *p_model from the checkpoint held in checked_model.
// Returns true when a non-empty checkpoint was loaded into p_model,
// false when the caller has to initialize the model itself.
// NOTE(review): as shown here, the `return false;` right below makes the rest
// of the body unreachable; it looks like a pre-change line left over from the
// diff -- confirm which lines are live.
// TODO
return false;
// check whether the load-check action succeeded
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
// if loaded model is empty, this simply means we did not call checkpoint yet
// ask caller to reinit model
if (checked_model.length() == 0) return false;
// load from buffer
utils::MemoryBufferStream fs(&checked_model);
p_model->Load(fs);
// reset result buffer
// the sequence counter restarts from zero together with the buffer
resbuf.Clear(); seq_counter = 0;
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
return true;
} else {
// nothing loaded, a fresh start, everyone init model
return false;
}
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* \param model the model to be checkpointed
*/
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
// Serialize `model` into checked_model, run the checkpoint action, reset the
// result buffer and sequence counter, then run the check-ack action.
// The statement order matters: the buffer is reset between the two
// RecoverExec phases.
// TODO
// save model
// checked_model is emptied first and then written through the memory stream
checked_model.resize(0);
utils::MemoryBufferStream fs(&checked_model);
model.Save(fs);
// an empty buffer is rejected here; LoadCheckPoint treats an empty
// checked_model as "no checkpoint yet"
utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something");
// execute checkpoint, note: when checkpoint exists, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");
// reset result buffer
// the sequence counter restarts from zero together with the buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
}
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker

View File

@ -154,7 +154,7 @@ class AllReduceRobust : public AllReduceBase {
rptr_.clear(); rptr_.push_back(0);
data_.clear();
}
// allocate temporal space for
// allocate temporary space
inline void *AllocTemp(size_t type_nbytes, size_t count) {
size_t size = type_nbytes * count;
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);

View File

@ -91,6 +91,7 @@ class IStream {
/*! \brief interface of serializable objects */
class ISerializable {
public:
/*! \brief load the model from the stream */
virtual void Load(IStream &fi) = 0;
/*! \brief save the model to the stream*/