diff --git a/include/rabit.h b/include/rabit.h index eb2834b30..7e3b88cdf 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -65,9 +65,8 @@ inline int GetRank(void); /*! \brief gets total number of processes */ inline int GetWorldSize(void); /*! \brief whether rabit env is in distributed mode */ -inline bool IsDistributed(void) { - return GetWorldSize() != 1; -} +inline bool IsDistributed(void); + /*! \brief gets processor's name */ inline std::string GetProcessorName(void); /*! diff --git a/include/rabit/engine.h b/include/rabit/engine.h index 668b4fcef..e0395cdcd 100644 --- a/include/rabit/engine.h +++ b/include/rabit/engine.h @@ -145,6 +145,8 @@ class IEngine { virtual int GetRank(void) const = 0; /*! \brief gets total number of nodes */ virtual int GetWorldSize(void) const = 0; + /*! \brief whether we run in distribted mode */ + virtual bool IsDistributed(void) const = 0; /*! \brief gets the host name of the current node */ virtual std::string GetHost(void) const = 0; /*! diff --git a/include/rabit/rabit-inl.h b/include/rabit/rabit-inl.h index e0a14f4ad..21d15d9e1 100644 --- a/include/rabit/rabit-inl.h +++ b/include/rabit/rabit-inl.h @@ -107,6 +107,10 @@ inline int GetRank(void) { inline int GetWorldSize(void) { return engine::GetEngine()->GetWorldSize(); } +// whether rabit is distributed +inline bool IsDistributed(void) { + return engine::GetEngine()->IsDistributed(); +} // get the name of current processor inline std::string GetProcessorName(void) { return engine::GetEngine()->GetHost(); diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 00dc60754..41b9f35f2 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -63,6 +63,10 @@ class AllreduceBase : public IEngine { if (world_size == -1) return 1; return world_size; } + /*! \brief whether is distributed or not */ + virtual bool IsDistributed(void) const { + return tracker_uri == "NULL"; + } /*! \brief get rank */ virtual std::string GetHost(void) const { return host_uri; diff --git a/src/engine_empty.cc b/src/engine_empty.cc index a29a35b0c..35e2a07d1 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -56,6 +56,10 @@ class EmptyEngine : public IEngine { virtual int GetWorldSize(void) const { return 1; } + /*! \brief whether it is distributed */ + virtual bool IsDistributed(void) const { + return false; + } /*! \brief get the host name of current node */ virtual std::string GetHost(void) const { return std::string(""); diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 829051231..c434e71d0 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -59,6 +59,10 @@ class MPIEngine : public IEngine { virtual int GetWorldSize(void) const { return MPI::COMM_WORLD.Get_size(); } + /*! \brief whether it is distributed */ + virtual bool IsDistributed(void) const { + return true; + } /*! \brief get the host name of current node */ virtual std::string GetHost(void) const { int len;