From 4c060df2f17405dc26dc65a77e412d5c2a23525a Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 9 Mar 2015 14:44:42 -0700 Subject: [PATCH] Squashed 'subtree/rabit/' changes from 28ca7be..d558f6f d558f6f redefine distributed means c8efc01 more complicated yarn script git-subtree-dir: subtree/rabit git-subtree-split: d558f6f550d156f2ed2bde6e2e71d09c19a19aac --- include/rabit.h | 5 ++--- include/rabit/engine.h | 2 ++ include/rabit/rabit-inl.h | 4 ++++ src/allreduce_base.h | 4 ++++ src/engine_empty.cc | 4 ++++ src/engine_mpi.cc | 4 ++++ tracker/rabit_yarn.py | 11 +++++++++-- 7 files changed, 29 insertions(+), 5 deletions(-) diff --git a/include/rabit.h b/include/rabit.h index eb2834b30..7e3b88cdf 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -65,9 +65,8 @@ inline int GetRank(void); /*! \brief gets total number of processes */ inline int GetWorldSize(void); /*! \brief whether rabit env is in distributed mode */ -inline bool IsDistributed(void) { - return GetWorldSize() != 1; -} +inline bool IsDistributed(void); + /*! \brief gets processor's name */ inline std::string GetProcessorName(void); /*! diff --git a/include/rabit/engine.h b/include/rabit/engine.h index 668b4fcef..e0395cdcd 100644 --- a/include/rabit/engine.h +++ b/include/rabit/engine.h @@ -145,6 +145,8 @@ class IEngine { virtual int GetRank(void) const = 0; /*! \brief gets total number of nodes */ virtual int GetWorldSize(void) const = 0; + /*! \brief whether we run in distributed mode */ + virtual bool IsDistributed(void) const = 0; /*! \brief gets the host name of the current node */ virtual std::string GetHost(void) const = 0; /*! 
diff --git a/include/rabit/rabit-inl.h b/include/rabit/rabit-inl.h index e0a14f4ad..21d15d9e1 100644 --- a/include/rabit/rabit-inl.h +++ b/include/rabit/rabit-inl.h @@ -107,6 +107,10 @@ inline int GetRank(void) { inline int GetWorldSize(void) { return engine::GetEngine()->GetWorldSize(); } +// whether rabit is distributed +inline bool IsDistributed(void) { + return engine::GetEngine()->IsDistributed(); +} // get the name of current processor inline std::string GetProcessorName(void) { return engine::GetEngine()->GetHost(); diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 00dc60754..41b9f35f2 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -63,6 +63,10 @@ class AllreduceBase : public IEngine { if (world_size == -1) return 1; return world_size; } + /*! \brief whether the engine runs in distributed mode, i.e. tracker_uri is set to something other than "NULL" */ + virtual bool IsDistributed(void) const { + return tracker_uri != "NULL"; + } /*! \brief get rank */ virtual std::string GetHost(void) const { return host_uri; diff --git a/src/engine_empty.cc b/src/engine_empty.cc index a29a35b0c..35e2a07d1 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -56,6 +56,10 @@ class EmptyEngine : public IEngine { virtual int GetWorldSize(void) const { return 1; } + /*! \brief whether it is distributed */ + virtual bool IsDistributed(void) const { + return false; + } /*! \brief get the host name of current node */ virtual std::string GetHost(void) const { return std::string(""); diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 829051231..c434e71d0 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -59,6 +59,10 @@ class MPIEngine : public IEngine { virtual int GetWorldSize(void) const { return MPI::COMM_WORLD.Get_size(); } + /*! \brief whether it is distributed */ + virtual bool IsDistributed(void) const { + return true; + } /*! 
\brief get the host name of current node */ virtual std::string GetHost(void) const { int len; diff --git a/tracker/rabit_yarn.py b/tracker/rabit_yarn.py index 3a4937278..81f590851 100755 --- a/tracker/rabit_yarn.py +++ b/tracker/rabit_yarn.py @@ -14,7 +14,13 @@ import rabit_tracker as tracker WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper' YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar' -assert os.path.exists(YARN_JAR_PATH), ("cannot find \"%s\", please run build.sh on the yarn folder" % YARN_JAR_PATH) +if not os.path.exists(YARN_JAR_PATH): + warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH) + cmd = 'cd %s && ./build.sh' % (os.path.dirname(__file__) + '/../yarn/') + print cmd + subprocess.check_call(cmd, shell = True, env = os.environ) + assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually" + hadoop_binary = 'hadoop' # code hadoop_home = os.getenv('HADOOP_HOME') @@ -116,7 +122,8 @@ def submit_yarn(nworker, worker_args, worker_env): cmd += ' -jobname %s ' % args.jobname cmd += ' -tempdir %s ' % args.tempdir cmd += (' '.join(args.command + worker_args)) - print cmd + if args.verbose != 0: + print cmd subprocess.check_call(cmd, shell = True, env = env) tracker.submit(args.nworker, [], fun_submit = submit_yarn, verbose = args.verbose, hostIP = args.host_ip)