add mock option statistics
This commit is contained in:
parent
75c647cd84
commit
d29892cb22
@ -9,7 +9,9 @@
|
||||
#define RABIT_ALLREDUCE_MOCK_H
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include "../include/rabit/engine.h"
|
||||
#include "../include/rabit/timer.h"
|
||||
#include "./allreduce_robust.h"
|
||||
|
||||
namespace rabit {
|
||||
@ -19,6 +21,9 @@ class AllreduceMock : public AllreduceRobust {
|
||||
// constructor
|
||||
AllreduceMock(void) {
|
||||
num_trial = 0;
|
||||
force_local = 0;
|
||||
report_stats = 0;
|
||||
tsum_allreduce = 0.0;
|
||||
}
|
||||
// destructor
|
||||
virtual ~AllreduceMock(void) {}
|
||||
@ -26,6 +31,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
AllreduceRobust::SetParam(name, val);
|
||||
// additional parameters
|
||||
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
||||
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
|
||||
if (!strcmp(name, "force_local")) force_local = atoi(val);
|
||||
if (!strcmp(name, "mock")) {
|
||||
MockKey k;
|
||||
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
||||
@ -41,25 +48,92 @@ class AllreduceMock : public AllreduceRobust {
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||
count, reducer, prepare_fun, prepare_arg);
|
||||
tsum_allreduce += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
// LoadCheckPoint override: restart the statistics (accumulated allreduce
// time and the checkpoint timestamp), then load through AllreduceRobust.
// When force_local is non-zero, both models are wrapped in one
// ComboSerializer and passed in the local-model slot, with an empty
// DummySerializer in the global-model slot.
// Returns whatever AllreduceRobust::LoadCheckPoint returns.
virtual int LoadCheckPoint(ISerializable *global_model,
                           ISerializable *local_model) {
  tsum_allreduce = 0.0;
  time_checkpoint = utils::GetTime();
  if (force_local != 0) {
    DummySerializer empty_global;
    ComboSerializer combined(global_model, local_model);
    return AllreduceRobust::LoadCheckPoint(&empty_global, &combined);
  }
  return AllreduceRobust::LoadCheckPoint(global_model, local_model);
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||
double tstart = utils::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint;
|
||||
if (force_local == 0) {
|
||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
ComboSerializer com(global_model, local_model);
|
||||
AllreduceRobust::CheckPoint(&dum, &com);
|
||||
}
|
||||
time_checkpoint = utils::GetTime();
|
||||
double tcost = utils::GetTime() - tstart;
|
||||
if (report_stats != 0 && rank == 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
|
||||
<< "local_size=" << local_chkpt[local_chkpt_version].length()
|
||||
<< "check_tcost="<< tcost <<" sec,"
|
||||
<< "allreduce_tcost=" << tsum_allreduce << " sec,"
|
||||
<< "between_chpt=" << tbet_chkpt << "sec\n";
|
||||
this->TrackerPrint(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
||||
AllreduceRobust::LazyCheckPoint(global_model);
|
||||
}
|
||||
protected:
|
||||
// force checkpoint to local
|
||||
int force_local;
|
||||
// whether to report statistics
|
||||
int report_stats;
|
||||
// accumulated time spent in Allreduce calls
|
||||
double tsum_allreduce;
|
||||
double time_checkpoint;
|
||||
|
||||
private:
|
||||
struct DummySerializer : public ISerializable {
|
||||
virtual void Load(IStream &fi) {
|
||||
}
|
||||
virtual void Save(IStream &fo) const {
|
||||
}
|
||||
};
|
||||
struct ComboSerializer : public ISerializable {
|
||||
ISerializable *lhs;
|
||||
ISerializable *rhs;
|
||||
const ISerializable *c_lhs;
|
||||
const ISerializable *c_rhs;
|
||||
ComboSerializer(ISerializable *lhs, ISerializable *rhs)
|
||||
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
|
||||
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
virtual void Load(IStream &fi) {
|
||||
if (lhs != NULL) lhs->Load(fi);
|
||||
if (rhs != NULL) rhs->Load(fi);
|
||||
}
|
||||
virtual void Save(IStream &fo) const {
|
||||
if (c_lhs != NULL) c_lhs->Save(fo);
|
||||
if (c_rhs != NULL) c_rhs->Save(fo);
|
||||
}
|
||||
};
|
||||
// key to identify the mock stage
|
||||
struct MockKey {
|
||||
int rank;
|
||||
|
||||
@ -138,7 +138,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
ReConnectLinks("recover");
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
// constant one byte out of band message to indicate error happening
|
||||
// and mark for channel cleanup
|
||||
static const char kOOBReset = 95;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user