From 6151899ce22d341511e7996ca09aa3235098a8b5 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 19 Dec 2014 18:40:06 -0800 Subject: [PATCH] add tracker print --- Makefile | 8 ++-- src/allreduce_base.cc | 46 ++++++++++--------- src/allreduce_base.h | 12 +++++ src/engine.h | 7 +++ src/engine_empty.cc | 90 +++++++++++++++++++++++++++++++++++++ src/engine_mpi.cc | 8 +++- src/rabit-inl.h | 16 +++++++ src/rabit.h | 17 +++++++ src/utils.h | 9 ---- test/speed_test.cpp | 4 +- test/test_local_recover.cpp | 12 ++--- test/test_model_recover.cpp | 12 ++--- tracker/rabit_tracker.py | 28 +++++++----- 13 files changed, 210 insertions(+), 59 deletions(-) create mode 100644 src/engine_empty.cc diff --git a/Makefile b/Makefile index bbed21c81..a591600a8 100644 --- a/Makefile +++ b/Makefile @@ -7,8 +7,8 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src BPATH=lib # objectives that makes up rabit library MPIOBJ= $(BPATH)/engine_mpi.o -OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o -ALIB= lib/librabit.a lib/librabit_mpi.a +OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o +ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a .PHONY: clean all @@ -18,8 +18,10 @@ $(BPATH)/allreduce_base.o: src/allreduce_base.cc src/*.h $(BPATH)/engine.o: src/engine.cc src/*.h $(BPATH)/allreduce_robust.o: src/allreduce_robust.cc src/*.h $(BPATH)/engine_mpi.o: src/engine_mpi.cc src/*.h +$(BPATH)/engine_empty.o: src/engine_empty.cc src/*.h -lib/librabit.a: $(OBJ) +lib/librabit.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o +lib/librabit_empty.a: $(BPATH)/engine_empty.o lib/librabit_mpi.a: $(MPIOBJ) $(OBJ) : diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index d2ab14daa..b8b3ed0de 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -67,24 +67,21 @@ void AllreduceBase::Shutdown(void) { tree_links.plinks.clear(); if (tracker_uri == "NULL") return; - int magic = kMagic; // notify tracker rank i have shutdown - utils::TCPSocket tracker; - tracker.Create(); - if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { - utils::Socket::Error("Connect Tracker"); - } - utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); - utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2"); - utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure"); - - utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); - utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3"); - tracker.SendStr(task_id); + utils::TCPSocket tracker = this->ConnectTracker(); tracker.SendStr(std::string("shutdown")); tracker.Close(); utils::TCPSocket::Finalize(); } +void AllreduceBase::TrackerPrint(const std::string &msg) { + if (tracker_uri == "NULL") { + utils::Printf("%s", msg.c_str()); return; + } + utils::TCPSocket tracker = this->ConnectTracker(); + tracker.SendStr(std::string("print")); + tracker.SendStr(msg); + tracker.Close(); +} /*! * \brief set parameters to the engine * \param name parameter name @@ -113,14 +110,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) { } } /*! - * \brief connect to the tracker to fix the the missing links - * this function is also used when the engine start up + * \brief initialize connection to the tracker + * \return a socket that initializes the connection */ -void AllreduceBase::ReConnectLinks(const char *cmd) { - // single node mode - if (tracker_uri == "NULL") { - rank = 0; world_size = 1; return; - } +utils::TCPSocket AllreduceBase::ConnectTracker(void) const { int magic = kMagic; // get information from tracker utils::TCPSocket tracker; @@ -134,7 +127,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3"); tracker.SendStr(task_id); + return tracker; +} +/*! + * \brief connect to the tracker to fix the the missing links + * this function is also used when the engine start up + */ +void AllreduceBase::ReConnectLinks(const char *cmd) { + // single node mode + if (tracker_uri == "NULL") { + rank = 0; world_size = 1; return; + } + utils::TCPSocket tracker = this->ConnectTracker(); tracker.SendStr(std::string(cmd)); + // the rank of previous link, next link in ring int prev_rank, next_rank; // the rank of neighbors diff --git a/src/allreduce_base.h b/src/allreduce_base.h index e313cab88..8bcc76781 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -45,6 +45,13 @@ class AllreduceBase : public IEngine { * \param val parameter value */ virtual void SetParam(const char *name, const char *val); + /*! + * \brief print the msg in the tracker, + * this function can be used to communicate the information of the progress to + * the user who monitors the tracker + * \param msg message to be printed in the tracker + */ + virtual void TrackerPrint(const std::string &msg); /*! \brief get rank */ virtual int GetRank(void) const { return rank; @@ -279,6 +286,11 @@ class AllreduceBase : public IEngine { return plinks.size(); } }; + /*! + * \brief initialize connection to the tracker + * \return a socket that initializes the connection + */ + utils::TCPSocket ConnectTracker(void) const; /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up diff --git a/src/engine.h b/src/engine.h index 0700b2a95..891290ae0 100644 --- a/src/engine.h +++ b/src/engine.h @@ -124,6 +124,13 @@ class IEngine { virtual int GetWorldSize(void) const = 0; /*! \brief get the host name of current node */ virtual std::string GetHost(void) const = 0; + /*! + * \brief print the msg in the tracker, + * this function can be used to communicate the information of the progress to + * the user who monitors the tracker + * \param msg message to be printed in the tracker + */ + virtual void TrackerPrint(const std::string &msg) = 0; }; /*! \brief intiialize the engine module */ diff --git a/src/engine_empty.cc b/src/engine_empty.cc new file mode 100644 index 000000000..a2cbd2358 --- /dev/null +++ b/src/engine_empty.cc @@ -0,0 +1,90 @@ +/*! + * \file engine_empty.cc + * \brief this file provides a dummy implementation of engine that does nothing + * this file provides a way to fall back to single node program without causing too many dependencies + * This is usually NOT needed, use engine_mpi or engine for real distributed version + * \author Tianqi Chen + */ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE +#define NOMINMAX + +#include "./engine.h" +namespace rabit { +namespace engine { +/*! \brief EmptyEngine */ +class EmptyEngine : public IEngine { + public: + EmptyEngine(void) { + version_number = 0; + } + virtual void Allreduce(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer, + PreprocFunction prepare_fun, + void *prepare_arg) { + utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead"); + } + virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { + } + virtual void InitAfterException(void) { + utils::Error("EmptyEngine is not fault tolerant"); + } + virtual int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model = NULL) { + return 0; + } + virtual void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model = NULL) { + version_number += 1; + } + virtual int VersionNumber(void) const { + return version_number; + } + /*! \brief get rank of current node */ + virtual int GetRank(void) const { + return 0; + } + /*! \brief get total number of */ + virtual int GetWorldSize(void) const { + return 1; + } + /*! \brief get the host name of current node */ + virtual std::string GetHost(void) const { + return std::string(""); + } + virtual void TrackerPrint(const std::string &msg) { + // simply print information into the tracker + utils::Printf("%s", msg.c_str()); + } + private: + int version_number; +}; + +// singleton sync manager +EmptyEngine manager; + +/*! \brief intiialize the synchronization module */ +void Init(int argc, char *argv[]) { +} +/*! \brief finalize syncrhonization module */ +void Finalize(void) { +} + +/*! \brief singleton method to get engine */ +IEngine *GetEngine(void) { + return &manager; +} +// perform in-place allreduce, on sendrecvbuf +void Allreduce_(void *sendrecvbuf, + size_t type_nbytes, + size_t count, + IEngine::ReduceFunction red, + mpi::DataType dtype, + mpi::OpType op, + IEngine::PreprocFunction prepare_fun, + void *prepare_arg) { +} +} // namespace engine +} // namespace rabit diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 870c93fdb..7bf1fa2b6 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -8,6 +8,7 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #define NOMINMAX +#include #include "./engine.h" #include "./utils.h" #include @@ -61,7 +62,12 @@ class MPIEngine : public IEngine { name[len] = '\0'; return std::string(name); } - + virtual void TrackerPrint(const std::string &msg) { + // simply print information into the tracker + if (GetRank() == 0) { + utils::Printf("%s", msg.c_str()); + } + } private: int version_number; }; diff --git a/src/rabit-inl.h b/src/rabit-inl.h index b6126f47d..8d681d32c 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -8,6 +8,7 @@ #define RABIT_RABIT_INL_H // use engine for implementation #include "./engine.h" +#include "./utils.h" namespace rabit { namespace engine { @@ -140,6 +141,21 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function pr } #endif // C++11 +// print message to the tracker +inline void TrackerPrint(const std::string &msg) { + engine::GetEngine()->TrackerPrint(msg); +} +#ifndef RABIT_STRICT_CXX98_ +inline void TrackerPrintf(const char *fmt, ...) { + const int kPrintBuffer = 1 << 10; + std::string msg(kPrintBuffer, '\0'); + va_list args; + va_start(args, fmt); + vsnprintf(&msg[0], kPrintBuffer, fmt, args); + va_end(args); + TrackerPrint(msg); +} +#endif // load latest check point inline int LoadCheckPoint(ISerializable *global_model, ISerializable *local_model) { diff --git a/src/rabit.h b/src/rabit.h index cc65e62ae..bdf80e259 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -47,6 +47,23 @@ inline int GetRank(void); inline int GetWorldSize(void); /*! \brief get name of processor */ inline std::string GetProcessorName(void); +/*! + * \brief print the msg to the tracker, + * this function can be used to communicate the information of the progress to + * the user who monitors the tracker + * \param msg, the message to be printed + */ +inline void TrackerPrint(const std::string &msg); +#ifndef RABIT_STRICT_CXX98_ +/*! + * \brief print the msg to the tracker, this function may not be available + * in very strict c++98 compilers, but is available most of the time + * this function can be used to communicate the information of the progress to + * the user who monitors the tracker + * \param fmt the format string + */ +inline void TrackerPrintf(const char *fmt, ...); +#endif /*! * \brief broadcast an memory region to all others from root * Example: int a = 1; Broadcast(&a, sizeof(a), root); diff --git a/src/utils.h b/src/utils.h index e1b34fe2e..beae6589f 100644 --- a/src/utils.h +++ b/src/utils.h @@ -106,15 +106,6 @@ inline void Printf(const char *fmt, ...) { va_end(args); HandlePrint(msg.c_str()); } -/*! \brief printf, print message to the console */ -inline void LogPrintf(const char *fmt, ...) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - HandleLogPrint(msg.c_str()); -} /*! \brief portable version of snprintf */ inline int SPrintf(char *buf, size_t size, const char *fmt, ...) { va_list args; diff --git a/test/speed_test.cpp b/test/speed_test.cpp index 8f7fc68bf..e716731fd 100644 --- a/test/speed_test.cpp +++ b/test/speed_test.cpp @@ -60,11 +60,11 @@ inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t s rabit::Allreduce(&tsqr, 1); double tstd = sqrt(tsqr / nproc); if (rabit::GetRank() == 0) { - utils::LogPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd); + rabit::TrackerPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd); double ndata = n; ndata *= nrep * size; if (n != 0) { - utils::LogPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 ); + rabit::TrackerPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 ); } } } diff --git a/test/test_local_recover.cpp b/test/test_local_recover.cpp index b9b84f2d1..d473345b3 100644 --- a/test/test_local_recover.cpp +++ b/test/test_local_recover.cpp @@ -147,22 +147,22 @@ int main(int argc, char *argv[]) { if (iter == 0) { model.InitModel(n, 1.0f); local.InitModel(n, 1.0f + rank); - utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } else { - utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } for (int r = iter; r < 3; ++r) { TestMax(&model, &local, ntrial, r); - utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r); + printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); int step = std::max(nproc / 3, 1); for (int i = 0; i < nproc; i += step) { TestBcast(n, i, ntrial, r); } - utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); + printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); TestSum(&model, &local, ntrial, r); - utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); rabit::CheckPoint(&model, &local); - utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); + printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); } break; } catch (MockException &e) { diff --git a/test/test_model_recover.cpp b/test/test_model_recover.cpp index aba107a85..117acef09 100644 --- a/test/test_model_recover.cpp +++ b/test/test_model_recover.cpp @@ -136,22 +136,22 @@ int main(int argc, char *argv[]) { int iter = rabit::LoadCheckPoint(&model); if (iter == 0) { model.InitModel(n); - utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } else { - utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } for (int r = iter; r < 3; ++r) { TestMax(&model, ntrial, r); - utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r); + printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); int step = std::max(nproc / 3, 1); for (int i = 0; i < nproc; i += step) { TestBcast(n, i, ntrial, r); } - utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); + printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); TestSum(&model, ntrial, r); - utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); rabit::CheckPoint(&model); - utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); + printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); } break; } catch (MockException &e) { diff --git a/tracker/rabit_tracker.py b/tracker/rabit_tracker.py index 8e05b4b5a..0322edf5b 100644 --- a/tracker/rabit_tracker.py +++ b/tracker/rabit_tracker.py @@ -188,6 +188,11 @@ class Tracker: rnext = (r + 1) % nslave ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) return ring_map + def handle_print(self,slave, msg): + sys.stdout.write(msg) + def log_print(self, msg): + sys.stderr.write(msg+'\n') + def accept_slaves(self, nslave): # set of nodes that finishs the job shutdown = {} @@ -202,12 +207,16 @@ class Tracker: while len(shutdown) != nslave: fd, s_addr = self.sock.accept() - s = SlaveEntry(fd, s_addr) + s = SlaveEntry(fd, s_addr) + if s.cmd == 'print': + msg = s.sock.recvstr() + self.handle_print(s, msg) + continue if s.cmd == 'shutdown': assert s.rank >= 0 and s.rank not in shutdown assert s.rank not in wait_conn shutdown[s.rank] = s - print 'Recieve %s signal from %d' % (s.cmd, s.rank) + self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank)) continue assert s.cmd == 'start' or s.cmd == 'recover' # lazily initialize the slaves @@ -233,21 +242,16 @@ class Tracker: job_map[s.jobid] = rank s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) if s.cmd != 'start': - print 'Recieve %s signal from %d' % (s.cmd, s.rank) + self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank)) else: - print 'Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank) + self.log_print('Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank)) if s.wait_accept > 0: wait_conn[rank] = s - print 'All nodes finishes job' + self.log_print('All nodes finishes job') -def mpi_submit(nslave, args): - cmd = ' '.join(['mpirun -n %d' % nslave] + args) - print cmd - return subprocess.check_call(cmd, shell = True) - -def submit(nslave, args, fun_submit = mpi_submit): +def submit(nslave, args, fun_submit): master = Tracker() submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args())) submit_thread.start() - master.accept_slaves(nslaves) + master.accept_slaves(nslave) submit_thread.join()