Squashed 'subtree/rabit/' changes from 4db0a62..d4ec037

d4ec037 fix rabit
6612fcf Merge branch 'master' of ssh://github.com/tqchen/rabit
d29892c add mock option statis
4fa054e new tracker
75c647c update tracker for host IP
e4ce8ef add hadoop linear example
76ecb4a add hadoop linear example
2e1c4c9 add hadoop linear example

git-subtree-dir: subtree/rabit
git-subtree-split: d4ec037f2e
This commit is contained in:
tqchen
2015-03-03 13:13:21 -08:00
parent 13776a006a
commit ef2de29f06
7 changed files with 119 additions and 13 deletions

View File

@@ -9,7 +9,9 @@
#define RABIT_ALLREDUCE_MOCK_H
#include <vector>
#include <map>
#include <sstream>
#include "../include/rabit/engine.h"
#include "../include/rabit/timer.h"
#include "./allreduce_robust.h"
namespace rabit {
@@ -19,6 +21,9 @@ class AllreduceMock : public AllreduceRobust {
// constructor
AllreduceMock(void) {
num_trial = 0;
force_local = 0;
report_stats = 0;
tsum_allreduce = 0.0;
}
// destructor
virtual ~AllreduceMock(void) {}
@@ -26,6 +31,8 @@ class AllreduceMock : public AllreduceRobust {
AllreduceRobust::SetParam(name, val);
// additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
if (!strcmp(name, "force_local")) force_local = atoi(val);
if (!strcmp(name, "mock")) {
MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d",
@@ -41,25 +48,93 @@ class AllreduceMock : public AllreduceRobust {
PreprocFunction prepare_fun,
void *prepare_arg) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg);
tsum_allreduce += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
}
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
tsum_allreduce = 0.0;
time_checkpoint = utils::GetTime();
if (force_local == 0) {
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
return AllreduceRobust::LoadCheckPoint(&dum, &com);
}
}
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
AllreduceRobust::CheckPoint(global_model, local_model);
double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint;
if (force_local == 0) {
AllreduceRobust::CheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com);
}
tsum_allreduce = 0.0;
time_checkpoint = utils::GetTime();
double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
<< "local_size=" << local_chkpt[local_chkpt_version].length()
<< "check_tcost="<< tcost <<" sec,"
<< "allreduce_tcost=" << tsum_allreduce << " sec,"
<< "between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model);
}
protected:
// force checkpoint to local
int force_local;
// whether report statistics
int report_stats;
// sum of allreduce
double tsum_allreduce;
double time_checkpoint;
private:
struct DummySerializer : public ISerializable {
virtual void Load(IStream &fi) {
}
virtual void Save(IStream &fo) const {
}
};
struct ComboSerializer : public ISerializable {
ISerializable *lhs;
ISerializable *rhs;
const ISerializable *c_lhs;
const ISerializable *c_rhs;
ComboSerializer(ISerializable *lhs, ISerializable *rhs)
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
}
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
}
virtual void Load(IStream &fi) {
if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi);
}
virtual void Save(IStream &fo) const {
if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo);
}
};
// key to identify the mock stage
struct MockKey {
int rank;

View File

@@ -138,7 +138,7 @@ class AllreduceRobust : public AllreduceBase {
ReConnectLinks("recover");
}
private:
protected:
// constant one byte out of band message to indicate error happening
// and mark for channel cleanup
static const char kOOBReset = 95;