diff --git a/subtree/rabit/include/rabit.h b/subtree/rabit/include/rabit.h index eb2834b30..7e3b88cdf 100644 --- a/subtree/rabit/include/rabit.h +++ b/subtree/rabit/include/rabit.h @@ -65,9 +65,8 @@ inline int GetRank(void); /*! \brief gets total number of processes */ inline int GetWorldSize(void); /*! \brief whether rabit env is in distributed mode */ -inline bool IsDistributed(void) { - return GetWorldSize() != 1; -} +inline bool IsDistributed(void); + /*! \brief gets processor's name */ inline std::string GetProcessorName(void); /*! diff --git a/subtree/rabit/include/rabit/engine.h b/subtree/rabit/include/rabit/engine.h index 668b4fcef..e0395cdcd 100644 --- a/subtree/rabit/include/rabit/engine.h +++ b/subtree/rabit/include/rabit/engine.h @@ -145,6 +145,8 @@ class IEngine { virtual int GetRank(void) const = 0; /*! \brief gets total number of nodes */ virtual int GetWorldSize(void) const = 0; + /*! \brief whether we run in distributed mode */ + virtual bool IsDistributed(void) const = 0; /*! \brief gets the host name of the current node */ virtual std::string GetHost(void) const = 0; /*! 
diff --git a/subtree/rabit/include/rabit/rabit-inl.h b/subtree/rabit/include/rabit/rabit-inl.h index e0a14f4ad..21d15d9e1 100644 --- a/subtree/rabit/include/rabit/rabit-inl.h +++ b/subtree/rabit/include/rabit/rabit-inl.h @@ -107,6 +107,10 @@ inline int GetRank(void) { inline int GetWorldSize(void) { return engine::GetEngine()->GetWorldSize(); } +// whether rabit is distributed +inline bool IsDistributed(void) { + return engine::GetEngine()->IsDistributed(); +} // get the name of current processor inline std::string GetProcessorName(void) { return engine::GetEngine()->GetHost(); diff --git a/subtree/rabit/src/allreduce_base.h b/subtree/rabit/src/allreduce_base.h index 00dc60754..41b9f35f2 100644 --- a/subtree/rabit/src/allreduce_base.h +++ b/subtree/rabit/src/allreduce_base.h @@ -63,6 +63,10 @@ class AllreduceBase : public IEngine { if (world_size == -1) return 1; return world_size; } + /*! \brief whether is distributed or not: true unless tracker_uri is "NULL" */ + virtual bool IsDistributed(void) const { + return tracker_uri != "NULL"; + } /*! \brief get rank */ virtual std::string GetHost(void) const { return host_uri; diff --git a/subtree/rabit/src/engine_empty.cc b/subtree/rabit/src/engine_empty.cc index a29a35b0c..35e2a07d1 100644 --- a/subtree/rabit/src/engine_empty.cc +++ b/subtree/rabit/src/engine_empty.cc @@ -56,6 +56,10 @@ class EmptyEngine : public IEngine { virtual int GetWorldSize(void) const { return 1; } + /*! \brief whether it is distributed */ + virtual bool IsDistributed(void) const { + return false; + } /*! \brief get the host name of current node */ virtual std::string GetHost(void) const { return std::string(""); diff --git a/subtree/rabit/src/engine_mpi.cc b/subtree/rabit/src/engine_mpi.cc index 829051231..c434e71d0 100644 --- a/subtree/rabit/src/engine_mpi.cc +++ b/subtree/rabit/src/engine_mpi.cc @@ -59,6 +59,10 @@ class MPIEngine : public IEngine { virtual int GetWorldSize(void) const { return MPI::COMM_WORLD.Get_size(); } + /*! 
\brief whether it is distributed */ + virtual bool IsDistributed(void) const { + return true; + } /*! \brief get the host name of current node */ virtual std::string GetHost(void) const { int len; diff --git a/subtree/rabit/tracker/rabit_yarn.py b/subtree/rabit/tracker/rabit_yarn.py index 04074e618..81f590851 100755 --- a/subtree/rabit/tracker/rabit_yarn.py +++ b/subtree/rabit/tracker/rabit_yarn.py @@ -16,7 +16,7 @@ YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar' if not os.path.exists(YARN_JAR_PATH): warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH) - cmd = 'cd %;./build.sh' % os.path.dirname(__file__) + '/../yarn/' + cmd = 'cd %s;./build.sh' % (os.path.dirname(__file__) + '/../yarn/') print cmd subprocess.check_call(cmd, shell = True, env = os.environ) assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually" @@ -122,7 +122,8 @@ def submit_yarn(nworker, worker_args, worker_env): cmd += ' -jobname %s ' % args.jobname cmd += ' -tempdir %s ' % args.tempdir cmd += (' '.join(args.command + worker_args)) - print cmd + if args.verbose != 0: + print cmd subprocess.check_call(cmd, shell = True, env = env) tracker.submit(args.nworker, [], fun_submit = submit_yarn, verbose = args.verbose, hostIP = args.host_ip)