add dmlc support
This commit is contained in:
parent
be2ff703bc
commit
8cb4c02165
@ -31,6 +31,7 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
// tracker URL
|
// tracker URL
|
||||||
task_id = "NULL";
|
task_id = "NULL";
|
||||||
err_link = NULL;
|
err_link = NULL;
|
||||||
|
dmlc_role = "worker";
|
||||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||||
// set up the environment variables of interest
|
// set up the environment variables of interest
|
||||||
env_vars.push_back("rabit_task_id");
|
env_vars.push_back("rabit_task_id");
|
||||||
@ -39,6 +40,12 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
env_vars.push_back("rabit_reduce_ring_mincount");
|
env_vars.push_back("rabit_reduce_ring_mincount");
|
||||||
env_vars.push_back("rabit_tracker_uri");
|
env_vars.push_back("rabit_tracker_uri");
|
||||||
env_vars.push_back("rabit_tracker_port");
|
env_vars.push_back("rabit_tracker_port");
|
||||||
|
// also include the environment variables set directly by DMLC
|
||||||
|
env_vars.push_back("DMLC_TASK_ID");
|
||||||
|
env_vars.push_back("DMLC_ROLE");
|
||||||
|
env_vars.push_back("DMLC_NUM_ATTEMPT");
|
||||||
|
env_vars.push_back("DMLC_TRACKER_URI");
|
||||||
|
env_vars.push_back("DMLC_TRACKER_PORT");
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
@ -86,6 +93,10 @@ void AllreduceBase::Init(void) {
|
|||||||
this->SetParam("rabit_world_size", num_task);
|
this->SetParam("rabit_world_size", num_task);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (dmlc_role != "worker") {
|
||||||
|
fprintf(stderr, "Rabit Module currently only work with dmlc worker, quit this program by exit 0\n");
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
// clear the setting before start reconnection
|
// clear the setting before start reconnection
|
||||||
this->rank = -1;
|
this->rank = -1;
|
||||||
//---------------------
|
//---------------------
|
||||||
@ -150,6 +161,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
|
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
|
||||||
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
|
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
|
||||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||||
|
if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val;
|
||||||
|
if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val);
|
||||||
|
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
|
||||||
|
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
|
||||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
||||||
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
||||||
|
|||||||
@ -496,6 +496,8 @@ class AllreduceBase : public IEngine {
|
|||||||
std::string host_uri;
|
std::string host_uri;
|
||||||
// uri of tracker
|
// uri of tracker
|
||||||
std::string tracker_uri;
|
std::string tracker_uri;
|
||||||
|
// role in dmlc jobs
|
||||||
|
std::string dmlc_role;
|
||||||
// port of tracker address
|
// port of tracker address
|
||||||
int tracker_port;
|
int tracker_port;
|
||||||
// port of slave process
|
// port of slave process
|
||||||
|
|||||||
@ -31,6 +31,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
AllreduceRobust::SetParam(name, val);
|
AllreduceRobust::SetParam(name, val);
|
||||||
// additional parameters
|
// additional parameters
|
||||||
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
||||||
|
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
|
||||||
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
|
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
|
||||||
if (!strcmp(name, "force_local")) force_local = atoi(val);
|
if (!strcmp(name, "force_local")) force_local = atoi(val);
|
||||||
if (!strcmp(name, "mock")) {
|
if (!strcmp(name, "mock")) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user