diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h
index 33b8b60ae..11005f4ba 100644
--- a/src/allreduce_mock.h
+++ b/src/allreduce_mock.h
@@ -9,7 +9,9 @@
 #define RABIT_ALLREDUCE_MOCK_H
 #include
 #include
+#include <sstream>
 #include "../include/rabit/engine.h"
+#include "../include/rabit/timer.h"
 #include "./allreduce_robust.h"
 
 namespace rabit {
@@ -19,6 +21,11 @@ class AllreduceMock : public AllreduceRobust {
   // constructor
   AllreduceMock(void) {
     num_trial = 0;
+    force_local = 0;
+    report_stats = 0;
+    tsum_allreduce = 0.0;
+    // start the clock now so CheckPoint never reads an indeterminate value
+    time_checkpoint = utils::GetTime();
   }
   // destructor
   virtual ~AllreduceMock(void) {}
@@ -26,6 +33,8 @@ class AllreduceMock : public AllreduceRobust {
     AllreduceRobust::SetParam(name, val);
     // additional parameters
     if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
+    if (!strcmp(name, "report_stats")) report_stats = atoi(val);
+    if (!strcmp(name, "force_local")) force_local = atoi(val);
     if (!strcmp(name, "mock")) {
       MockKey k;
       utils::Check(sscanf(val, "%d,%d,%d,%d",
@@ -41,25 +50,102 @@ class AllreduceMock : public AllreduceRobust {
                          PreprocFunction prepare_fun,
                          void *prepare_arg) {
     this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
+    // accumulate the time spent inside allreduce for the statistics report
+    double tstart = utils::GetTime();
     AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
                                count, reducer, prepare_fun, prepare_arg);
+    tsum_allreduce += utils::GetTime() - tstart;
   }
   virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
     this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
     AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
   }
+  // restart the timing statistics, then load the checkpoint; when force_local
+  // is set, both models are passed as one combined local model
+  virtual int LoadCheckPoint(ISerializable *global_model,
+                             ISerializable *local_model) {
+    tsum_allreduce = 0.0;
+    time_checkpoint = utils::GetTime();
+    if (force_local == 0) {
+      return AllreduceRobust::LoadCheckPoint(global_model, local_model);
+    } else {
+      DummySerializer dum;
+      ComboSerializer com(global_model, local_model);
+      return AllreduceRobust::LoadCheckPoint(&dum, &com);
+    }
+  }
   virtual void CheckPoint(const ISerializable *global_model,
                           const ISerializable *local_model) {
     this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
-    AllreduceRobust::CheckPoint(global_model, local_model);
+    double tstart = utils::GetTime();
+    // time elapsed since the previous checkpoint (or load) finished
+    double tbet_chkpt = tstart - time_checkpoint;
+    if (force_local == 0) {
+      AllreduceRobust::CheckPoint(global_model, local_model);
+    } else {
+      DummySerializer dum;
+      ComboSerializer com(global_model, local_model);
+      AllreduceRobust::CheckPoint(&dum, &com);
+    }
+    time_checkpoint = utils::GetTime();
+    double tcost = time_checkpoint - tstart;
+    if (report_stats != 0 && rank == 0) {
+      std::stringstream ss;
+      ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
+         << ",local_size=" << local_chkpt[local_chkpt_version].length()
+         << ",check_tcost=" << tcost << " sec"
+         << ",allreduce_tcost=" << tsum_allreduce << " sec"
+         << ",between_chpt=" << tbet_chkpt << " sec\n";
+      this->TrackerPrint(ss.str());
+    }
+    // allreduce time is reported per checkpoint interval, like between_chpt
+    tsum_allreduce = 0.0;
   }
   virtual void LazyCheckPoint(const ISerializable *global_model) {
     this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
     AllreduceRobust::LazyCheckPoint(global_model);
   }
 
+ protected:
+  // force checkpoint to local
+  int force_local;
+  // whether report statistics
+  int report_stats;
+  // time accumulated inside Allreduce since the last checkpoint or load
+  double tsum_allreduce;
+  // utils::GetTime() value taken when the last checkpoint or load happened
+  double time_checkpoint;
 
  private:
+  // serializer that reads and writes nothing; passed as the global model
+  struct DummySerializer : public ISerializable {
+    virtual void Load(IStream &fi) {
+    }
+    virtual void Save(IStream &fo) const {
+    }
+  };
+  // serializer that handles two models back to back, skipping NULL ones
+  struct ComboSerializer : public ISerializable {
+    ISerializable *lhs;
+    ISerializable *rhs;
+    const ISerializable *c_lhs;
+    const ISerializable *c_rhs;
+    ComboSerializer(ISerializable *lhs, ISerializable *rhs)
+        : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
+    }
+    // save-only form: the mutable pointers stay NULL, so Load does nothing
+    ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
+        : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
+    }
+    virtual void Load(IStream &fi) {
+      if (lhs != NULL) lhs->Load(fi);
+      if (rhs != NULL) rhs->Load(fi);
+    }
+    virtual void Save(IStream &fo) const {
+      if (c_lhs != NULL) c_lhs->Save(fo);
+      if (c_rhs != NULL) c_rhs->Save(fo);
+    }
+  };
   // key to identify the mock stage
   struct MockKey {
     int rank;
diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h
index ff5c046ac..34b740dfe 100644
--- a/src/allreduce_robust.h
+++ b/src/allreduce_robust.h
@@ -138,7 +138,7 @@ class AllreduceRobust : public AllreduceBase {
     ReConnectLinks("recover");
   }
 
- private:
+ protected:
   // constant one byte out of band message to indicate error happening
   // and mark for channel cleanup
   static const char kOOBReset = 95;