add tracker print

This commit is contained in:
tqchen 2014-12-19 18:40:06 -08:00
parent 6bf282c6c2
commit 6151899ce2
13 changed files with 210 additions and 59 deletions

View File

@ -7,8 +7,8 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
BPATH=lib
# objectives that makes up rabit library
MPIOBJ= $(BPATH)/engine_mpi.o
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o
ALIB= lib/librabit.a lib/librabit_mpi.a
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a
.PHONY: clean all
@ -18,8 +18,10 @@ $(BPATH)/allreduce_base.o: src/allreduce_base.cc src/*.h
$(BPATH)/engine.o: src/engine.cc src/*.h
$(BPATH)/allreduce_robust.o: src/allreduce_robust.cc src/*.h
$(BPATH)/engine_mpi.o: src/engine_mpi.cc src/*.h
$(BPATH)/engine_empty.o: src/engine_empty.cc src/*.h
lib/librabit.a: $(OBJ)
lib/librabit.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o
lib/librabit_empty.a: $(BPATH)/engine_empty.o
lib/librabit_mpi.a: $(MPIOBJ)
$(OBJ) :

View File

@ -67,24 +67,21 @@ void AllreduceBase::Shutdown(void) {
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
int magic = kMagic;
// notify tracker rank i have shutdown
utils::TCPSocket tracker;
tracker.Create();
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect Tracker");
}
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
tracker.SendStr(task_id);
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
utils::TCPSocket::Finalize();
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
tracker.Close();
}
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -113,14 +110,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
}
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
}
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
@ -134,7 +127,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
tracker.SendStr(task_id);
return tracker;
}
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine starts up
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string(cmd));
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors

View File

@ -45,6 +45,13 @@ class AllreduceBase : public IEngine {
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
@ -279,6 +286,11 @@ class AllreduceBase : public IEngine {
return plinks.size();
}
};
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine starts up

View File

@ -124,6 +124,13 @@ class IEngine {
virtual int GetWorldSize(void) const = 0;
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const = 0;
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg) = 0;
};
/*! \brief initialize the engine module */

90
src/engine_empty.cc Normal file
View File

@ -0,0 +1,90 @@
/*!
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
* This is usually NOT needed, use engine_mpi or engine for real distributed version
* \author Tianqi Chen
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include "./engine.h"
namespace rabit {
namespace engine {
/*!
 * \brief engine that performs no communication:
 *  it reports rank 0 in a world of size 1 and only keeps a version counter
 */
class EmptyEngine : public IEngine {
 public:
  EmptyEngine(void) : version_number(0) {}
  /*! \brief not supported by this engine, always reports an error */
  virtual void Allreduce(void *sendrecvbuf_,
                         size_t type_nbytes,
                         size_t count,
                         ReduceFunction reducer,
                         PreprocFunction prepare_fun,
                         void *prepare_arg) {
    utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
  }
  /*! \brief broadcast is a no-op, the buffer is left as it is */
  virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
  }
  /*! \brief recovery is not available, always reports an error */
  virtual void InitAfterException(void) {
    utils::Error("EmptyEngine is not fault tolerant");
  }
  /*! \brief nothing is ever loaded, always returns version 0 */
  virtual int LoadCheckPoint(ISerializable *global_model,
                             ISerializable *local_model = NULL) {
    return 0;
  }
  /*! \brief the models are not stored, only the version counter is advanced */
  virtual void CheckPoint(const ISerializable *global_model,
                          const ISerializable *local_model = NULL) {
    ++version_number;
  }
  /*! \brief number of times CheckPoint has been called */
  virtual int VersionNumber(void) const {
    return version_number;
  }
  /*! \brief get rank of current node, always 0 */
  virtual int GetRank(void) const {
    return 0;
  }
  /*! \brief get total number of nodes, always 1 */
  virtual int GetWorldSize(void) const {
    return 1;
  }
  /*! \brief get the host name of current node, always the empty string */
  virtual std::string GetHost(void) const {
    return std::string();
  }
  /*! \brief there is no tracker, so the message is printed locally */
  virtual void TrackerPrint(const std::string &msg) {
    utils::Printf("%s", msg.c_str());
  }

 private:
  /*! \brief counter incremented by every CheckPoint call */
  int version_number;
};
// singleton sync manager
EmptyEngine manager;
/*!
 * \brief initialize the synchronization module
 *  this implementation has nothing to set up, both arguments are ignored
 * \param argc argument count as passed by the caller (unused)
 * \param argv argument vector as passed by the caller (unused)
 */
void Init(int argc, char *argv[]) {
}
/*!
 * \brief finalize the synchronization module
 *  this implementation holds no resources, so there is nothing to release
 */
void Finalize(void) {
}
/*!
 * \brief singleton method to get engine
 * \return pointer to the file-level EmptyEngine instance `manager`
 */
IEngine *GetEngine(void) {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
}
} // namespace engine
} // namespace rabit

View File

@ -8,6 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <cstdio>
#include "./engine.h"
#include "./utils.h"
#include <mpi.h>
@ -61,7 +62,12 @@ class MPIEngine : public IEngine {
name[len] = '\0';
return std::string(name);
}
virtual void TrackerPrint(const std::string &msg) {
  // only rank 0 prints the message, every other rank stays silent
  if (GetRank() != 0) return;
  utils::Printf("%s", msg.c_str());
}
private:
int version_number;
};

View File

@ -8,6 +8,7 @@
#define RABIT_RABIT_INL_H
// use engine for implementation
#include "./engine.h"
#include "./utils.h"
namespace rabit {
namespace engine {
@ -140,6 +141,21 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
}
#endif // C++11
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);
}
#ifndef RABIT_STRICT_CXX98_
/*!
 * \brief format a message printf-style and pass it on to TrackerPrint
 *  the formatted text is limited to kPrintBuffer - 1 characters,
 *  anything longer is truncated
 * \param fmt the format string, followed by the arguments it consumes
 */
inline void TrackerPrintf(const char *fmt, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string msg(kPrintBuffer, '\0');
  va_list args;
  va_start(args, fmt);
  vsnprintf(&msg[0], kPrintBuffer, fmt, args);
  va_end(args);
  // msg is still kPrintBuffer bytes long and padded with '\0'; cut it at the
  // first terminator so only the formatted text is sent, not the padding
  TrackerPrint(std::string(msg.c_str()));
}
#endif
// load latest check point
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {

View File

@ -47,6 +47,23 @@ inline int GetRank(void);
inline int GetWorldSize(void);
/*! \brief get name of processor */
inline std::string GetProcessorName(void);
/*!
* \brief print the msg to the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg, the message to be printed
*/
inline void TrackerPrint(const std::string &msg);
#ifndef RABIT_STRICT_CXX98_
/*!
* \brief print the msg to the tracker, this function may not be available
* in very strict c++98 compilers, but is available most of the time
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param fmt the format string
*/
inline void TrackerPrintf(const char *fmt, ...);
#endif
/*!
* \brief broadcast an memory region to all others from root
* Example: int a = 1; Broadcast(&a, sizeof(a), root);

View File

@ -106,15 +106,6 @@ inline void Printf(const char *fmt, ...) {
va_end(args);
HandlePrint(msg.c_str());
}
/*! \brief printf, print message to the console */
inline void LogPrintf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleLogPrint(msg.c_str());
}
/*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args;

View File

@ -60,11 +60,11 @@ inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t s
rabit::Allreduce<op::Sum>(&tsqr, 1);
double tstd = sqrt(tsqr / nproc);
if (rabit::GetRank() == 0) {
utils::LogPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
rabit::TrackerPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
double ndata = n;
ndata *= nrep * size;
if (n != 0) {
utils::LogPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
rabit::TrackerPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
}
}
}

View File

@ -147,22 +147,22 @@ int main(int argc, char *argv[]) {
if (iter == 0) {
model.InitModel(n, 1.0f);
local.InitModel(n, 1.0f + rank);
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
} else {
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
}
for (int r = iter; r < 3; ++r) {
TestMax(&model, &local, ntrial, r);
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
int step = std::max(nproc / 3, 1);
for (int i = 0; i < nproc; i += step) {
TestBcast(n, i, ntrial, r);
}
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
TestSum(&model, &local, ntrial, r);
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
rabit::CheckPoint(&model, &local);
utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
}
break;
} catch (MockException &e) {

View File

@ -136,22 +136,22 @@ int main(int argc, char *argv[]) {
int iter = rabit::LoadCheckPoint(&model);
if (iter == 0) {
model.InitModel(n);
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
} else {
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
}
for (int r = iter; r < 3; ++r) {
TestMax(&model, ntrial, r);
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
int step = std::max(nproc / 3, 1);
for (int i = 0; i < nproc; i += step) {
TestBcast(n, i, ntrial, r);
}
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
TestSum(&model, ntrial, r);
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
rabit::CheckPoint(&model);
utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
}
break;
} catch (MockException &e) {

View File

@ -188,6 +188,11 @@ class Tracker:
rnext = (r + 1) % nslave
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map
def handle_print(self, slave, msg):
    """Write a message received from a slave to the tracker's stdout.

    slave is the entry of the sender (not used here); msg is written as is,
    no newline is appended.
    """
    sys.stdout.write(msg)
    # flush right away: when stdout is redirected to a file or a pipe it is
    # block-buffered, and progress messages would otherwise show up late
    sys.stdout.flush()
def log_print(self, msg):
    """Write msg followed by a newline to stderr."""
    line = msg + '\n'
    sys.stderr.write(line)
def accept_slaves(self, nslave):
# set of nodes that finishs the job
shutdown = {}
@ -202,12 +207,16 @@ class Tracker:
while len(shutdown) != nslave:
fd, s_addr = self.sock.accept()
s = SlaveEntry(fd, s_addr)
s = SlaveEntry(fd, s_addr)
if s.cmd == 'print':
msg = s.sock.recvstr()
self.handle_print(s, msg)
continue
if s.cmd == 'shutdown':
assert s.rank >= 0 and s.rank not in shutdown
assert s.rank not in wait_conn
shutdown[s.rank] = s
print 'Recieve %s signal from %d' % (s.cmd, s.rank)
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank))
continue
assert s.cmd == 'start' or s.cmd == 'recover'
# lazily initialize the slaves
@ -233,21 +242,16 @@ class Tracker:
job_map[s.jobid] = rank
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.cmd != 'start':
print 'Recieve %s signal from %d' % (s.cmd, s.rank)
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank))
else:
print 'Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank)
self.log_print('Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank))
if s.wait_accept > 0:
wait_conn[rank] = s
print 'All nodes finishes job'
self.log_print('All nodes finishes job')
def mpi_submit(nslave, args):
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
print cmd
return subprocess.check_call(cmd, shell = True)
def submit(nslave, args, fun_submit = mpi_submit):
def submit(nslave, args, fun_submit):
master = Tracker()
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
submit_thread.start()
master.accept_slaves(nslaves)
master.accept_slaves(nslave)
submit_thread.join()