Merge commit '4c060df2f17405dc26dc65a77e412d5c2a23525a'

Conflicts:
	subtree/rabit/tracker/rabit_yarn.py
This commit is contained in:
tqchen 2015-03-09 14:45:23 -07:00
commit 8f7e9abf89
7 changed files with 23 additions and 5 deletions

View File

@ -65,9 +65,8 @@ inline int GetRank(void);
/*! \brief gets total number of processes */
inline int GetWorldSize(void);
/*! \brief whether rabit env is in distributed mode */
inline bool IsDistributed(void) {
  // any world size other than exactly one process counts as distributed
  const int world_size = GetWorldSize();
  return world_size != 1;
}
inline bool IsDistributed(void);
/*! \brief gets processor's name */
inline std::string GetProcessorName(void);
/*!

View File

@ -145,6 +145,8 @@ class IEngine {
virtual int GetRank(void) const = 0;
/*! \brief gets total number of nodes */
virtual int GetWorldSize(void) const = 0;
/*! \brief whether we run in distributed mode */
virtual bool IsDistributed(void) const = 0;
/*! \brief gets the host name of the current node */
virtual std::string GetHost(void) const = 0;
/*!

View File

@ -107,6 +107,10 @@ inline int GetRank(void) {
inline int GetWorldSize(void) {
  // forward the query to the engine returned by engine::GetEngine()
  const int num_procs = engine::GetEngine()->GetWorldSize();
  return num_procs;
}
// report whether rabit runs in distributed mode, as answered by the engine
inline bool IsDistributed(void) {
  const bool distributed = engine::GetEngine()->IsDistributed();
  return distributed;
}
// get the name of current processor
inline std::string GetProcessorName(void) {
return engine::GetEngine()->GetHost();

View File

@ -63,6 +63,10 @@ class AllreduceBase : public IEngine {
if (world_size == -1) return 1;
return world_size;
}
/*!
 * \brief whether the engine runs in distributed mode
 * \return true when tracker_uri differs from the literal "NULL", false otherwise
 * NOTE(review): assumes "NULL" is the placeholder stored in tracker_uri when no
 *   tracker has been configured - confirm against the constructor/SetParam.
 */
virtual bool IsDistributed(void) const {
  // The previous check used ==, which reported "distributed" exactly when the
  // tracker URI was the "NULL" placeholder; that disagreed with the other
  // engines (EmptyEngine -> false, MPIEngine -> true).
  return tracker_uri != "NULL";
}
/*! \brief gets the host name of the current node */
virtual std::string GetHost(void) const {
return host_uri;

View File

@ -56,6 +56,10 @@ class EmptyEngine : public IEngine {
virtual int GetWorldSize(void) const {
  // this engine always reports a world of exactly one process
  const int single_process = 1;
  return single_process;
}
/*! \brief whether the engine is distributed; this engine always answers no */
virtual bool IsDistributed(void) const {
  const bool distributed = false;
  return distributed;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
return std::string("");

View File

@ -59,6 +59,10 @@ class MPIEngine : public IEngine {
virtual int GetWorldSize(void) const {
  // the world size is whatever MPI::COMM_WORLD reports as its size
  const int comm_size = MPI::COMM_WORLD.Get_size();
  return comm_size;
}
/*! \brief whether the engine is distributed; this engine always answers yes */
virtual bool IsDistributed(void) const {
  const bool distributed = true;
  return distributed;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
int len;

View File

@ -16,7 +16,7 @@ YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
# When nothing exists at YARN_JAR_PATH, warn and run build.sh from the ../yarn/
# directory next to this script; the final assert checks that the jar exists.
if not os.path.exists(YARN_JAR_PATH):
warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
# NOTE(review): cmd is assigned on two consecutive lines. The first format string
# has a bare '%' directly before ';', which is not a valid conversion specifier,
# and its '%' binds before the '+'. The second line uses '%s' with the whole path
# parenthesized. This looks like the before/after pair of a diff - confirm that
# only the second assignment is meant to remain.
cmd = 'cd %;./build.sh' % os.path.dirname(__file__) + '/../yarn/'
cmd = 'cd %s;./build.sh' % (os.path.dirname(__file__) + '/../yarn/')
print cmd
# shell = True because cmd chains two commands with ';'
subprocess.check_call(cmd, shell = True, env = os.environ)
assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
@ -122,7 +122,8 @@ def submit_yarn(nworker, worker_args, worker_env):
cmd += ' -jobname %s ' % args.jobname
cmd += ' -tempdir %s ' % args.tempdir
cmd += (' '.join(args.command + worker_args))
print cmd
if args.verbose != 0:
print cmd
subprocess.check_call(cmd, shell = True, env = env)
tracker.submit(args.nworker, [], fun_submit = submit_yarn, verbose = args.verbose, hostIP = args.host_ip)