From 8cb4c021655557a421b09c635d7042c472a239f8 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 28 Mar 2015 22:44:10 -0700 Subject: [PATCH] add dmlc support --- src/allreduce_base.cc | 15 +++++++++++++++ src/allreduce_base.h | 2 ++ src/allreduce_mock.h | 1 + 3 files changed, 18 insertions(+) diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 0235723a6..d0eff0425 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -31,6 +31,7 @@ AllreduceBase::AllreduceBase(void) { // tracker URL task_id = "NULL"; err_link = NULL; + dmlc_role = "worker"; this->SetParam("rabit_reduce_buffer", "256MB"); // setup possible enviroment variable of intrest env_vars.push_back("rabit_task_id"); @@ -39,6 +40,12 @@ AllreduceBase::AllreduceBase(void) { env_vars.push_back("rabit_reduce_ring_mincount"); env_vars.push_back("rabit_tracker_uri"); env_vars.push_back("rabit_tracker_port"); + // also include dmlc support direct variables + env_vars.push_back("DMLC_TASK_ID"); + env_vars.push_back("DMLC_ROLE"); + env_vars.push_back("DMLC_NUM_ATTEMPT"); + env_vars.push_back("DMLC_TRACKER_URI"); + env_vars.push_back("DMLC_TRACKER_PORT"); } // initialization function @@ -86,6 +93,10 @@ void AllreduceBase::Init(void) { this->SetParam("rabit_world_size", num_task); } } + if (dmlc_role != "worker") { + fprintf(stderr, "Rabit Module currently only work with dmlc worker, quit this program by exit 0\n"); + exit(0); + } // clear the setting before start reconnection this->rank = -1; //--------------------- @@ -150,6 +161,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val; if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val); if (!strcmp(name, "rabit_task_id")) task_id = val; + if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val; + if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val); + if (!strcmp(name, "DMLC_TASK_ID")) task_id = val; + if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val; if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); if (!strcmp(name, "rabit_reduce_ring_mincount")) { diff --git a/src/allreduce_base.h b/src/allreduce_base.h index a9eafea39..690c27d8a 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -496,6 +496,8 @@ class AllreduceBase : public IEngine { std::string host_uri; // uri of tracker std::string tracker_uri; + // role in dmlc jobs + std::string dmlc_role; // port of tracker address int tracker_port; // port of slave process diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 67f8d80dd..666acbeef 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -31,6 +31,7 @@ class AllreduceMock : public AllreduceRobust { AllreduceRobust::SetParam(name, val); // additional parameters if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val); + if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val); if (!strcmp(name, "report_stats")) report_stats = atoi(val); if (!strcmp(name, "force_local")) force_local = atoi(val); if (!strcmp(name, "mock")) {