add mock option statistics

This commit is contained in:
tqchen 2015-03-02 16:10:08 -08:00
parent 75c647cd84
commit d29892cb22
2 changed files with 76 additions and 2 deletions

View File

@ -9,7 +9,9 @@
#define RABIT_ALLREDUCE_MOCK_H #define RABIT_ALLREDUCE_MOCK_H
#include <vector> #include <vector>
#include <map> #include <map>
#include <sstream>
#include "../include/rabit/engine.h" #include "../include/rabit/engine.h"
#include "../include/rabit/timer.h"
#include "./allreduce_robust.h" #include "./allreduce_robust.h"
namespace rabit { namespace rabit {
@ -19,6 +21,9 @@ class AllreduceMock : public AllreduceRobust {
// constructor // constructor
AllreduceMock(void) { AllreduceMock(void) {
num_trial = 0; num_trial = 0;
force_local = 0;
report_stats = 0;
tsum_allreduce = 0.0;
} }
// destructor // destructor
virtual ~AllreduceMock(void) {} virtual ~AllreduceMock(void) {}
@ -26,6 +31,8 @@ class AllreduceMock : public AllreduceRobust {
AllreduceRobust::SetParam(name, val); AllreduceRobust::SetParam(name, val);
// additional parameters // additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val); if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
if (!strcmp(name, "force_local")) force_local = atoi(val);
if (!strcmp(name, "mock")) { if (!strcmp(name, "mock")) {
MockKey k; MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d", utils::Check(sscanf(val, "%d,%d,%d,%d",
@ -41,25 +48,92 @@ class AllreduceMock : public AllreduceRobust {
PreprocFunction prepare_fun, PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg); count, reducer, prepare_fun, prepare_arg);
tsum_allreduce += utils::GetTime() - tstart;
} }
// Mock-instrumented Broadcast: calls Verify with the key
// (rank, version_number, seq_counter, num_trial) tagged "Broadcast", then
// forwards the unchanged arguments to AllreduceRobust::Broadcast.
// NOTE(review): Verify is defined outside this view; presumably it triggers
// the failure configured through the "mock" parameter - confirm.
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root); AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
} }
// Loads a checkpoint through AllreduceRobust::LoadCheckPoint and restarts the
// timing statistics (accumulated Allreduce time and last-checkpoint time).
// When force_local is nonzero, the base class is given an empty
// DummySerializer as the global model and a ComboSerializer wrapping both
// real models as the local model.
// Returns whatever AllreduceRobust::LoadCheckPoint returns.
virtual int LoadCheckPoint(ISerializable *global_model,
                           ISerializable *local_model) {
  // restart the statistics before any loading work happens
  this->tsum_allreduce = 0.0;
  this->time_checkpoint = utils::GetTime();
  if (force_local != 0) {
    DummySerializer empty_global;
    ComboSerializer merged_local(global_model, local_model);
    return AllreduceRobust::LoadCheckPoint(&empty_global, &merged_local);
  }
  return AllreduceRobust::LoadCheckPoint(global_model, local_model);
}
virtual void CheckPoint(const ISerializable *global_model, virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) { const ISerializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
AllreduceRobust::CheckPoint(global_model, local_model); double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint;
if (force_local == 0) {
AllreduceRobust::CheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com);
}
time_checkpoint = utils::GetTime();
double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
<< "local_size=" << local_chkpt[local_chkpt_version].length()
<< "check_tcost="<< tcost <<" sec,"
<< "allreduce_tcost=" << tsum_allreduce << " sec,"
<< "between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
} }
// Mock-instrumented LazyCheckPoint: calls Verify with the key
// (rank, version_number, seq_counter, num_trial) tagged "LazyCheckPoint",
// then forwards global_model to AllreduceRobust::LazyCheckPoint.
// No timing statistics are collected on this path.
virtual void LazyCheckPoint(const ISerializable *global_model) { virtual void LazyCheckPoint(const ISerializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model); AllreduceRobust::LazyCheckPoint(global_model);
} }
protected:
// when nonzero (set through the "force_local" parameter), LoadCheckPoint and
// CheckPoint hand the base class a DummySerializer as the global model and a
// ComboSerializer wrapping both real models as the local model
int force_local;
// when nonzero (set through the "report_stats" parameter), rank 0 prints a
// statistics line from CheckPoint
int report_stats;
// running total of the time spent inside Allreduce calls, measured with
// utils::GetTime() and reported as seconds; zeroed by the constructor and by
// LoadCheckPoint
double tsum_allreduce;
// utils::GetTime() reading taken in LoadCheckPoint and after each CheckPoint;
// CheckPoint subtracts it from its start time to get the time between
// checkpoints
double time_checkpoint;
private: private:
// Serializable object with no content: Load consumes nothing from the stream
// and Save writes nothing to it.
struct DummySerializer : public ISerializable {
  virtual void Load(IStream &) {}
  virtual void Save(IStream &) const {}
};
struct ComboSerializer : public ISerializable {
ISerializable *lhs;
ISerializable *rhs;
const ISerializable *c_lhs;
const ISerializable *c_rhs;
ComboSerializer(ISerializable *lhs, ISerializable *rhs)
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
}
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
}
virtual void Load(IStream &fi) {
if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi);
}
virtual void Save(IStream &fo) const {
if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo);
}
};
// key to identify the mock stage // key to identify the mock stage
struct MockKey { struct MockKey {
int rank; int rank;

View File

@ -138,7 +138,7 @@ class AllreduceRobust : public AllreduceBase {
ReConnectLinks("recover"); ReConnectLinks("recover");
} }
private: protected:
// constant one byte out of band message to indicate error happening // constant one byte out of band message to indicate error happening
// and mark for channel cleanup // and mark for channel cleanup
static const char kOOBReset = 95; static const char kOOBReset = 95;