add tracker print
This commit is contained in:
parent
6bf282c6c2
commit
6151899ce2
8
Makefile
8
Makefile
@ -7,8 +7,8 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
|
||||
BPATH=lib
|
||||
# object files that make up the rabit library
|
||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o
|
||||
ALIB= lib/librabit.a lib/librabit_mpi.a
|
||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o
|
||||
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a
|
||||
|
||||
.PHONY: clean all
|
||||
|
||||
@ -18,8 +18,10 @@ $(BPATH)/allreduce_base.o: src/allreduce_base.cc src/*.h
|
||||
$(BPATH)/engine.o: src/engine.cc src/*.h
|
||||
$(BPATH)/allreduce_robust.o: src/allreduce_robust.cc src/*.h
|
||||
$(BPATH)/engine_mpi.o: src/engine_mpi.cc src/*.h
|
||||
$(BPATH)/engine_empty.o: src/engine_empty.cc src/*.h
|
||||
|
||||
lib/librabit.a: $(OBJ)
|
||||
lib/librabit.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o
|
||||
lib/librabit_empty.a: $(BPATH)/engine_empty.o
|
||||
lib/librabit_mpi.a: $(MPIOBJ)
|
||||
|
||||
$(OBJ) :
|
||||
|
||||
@ -67,24 +67,21 @@ void AllreduceBase::Shutdown(void) {
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
int magic = kMagic;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker;
|
||||
tracker.Create();
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
utils::Socket::Error("Connect Tracker");
|
||||
}
|
||||
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
|
||||
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||
|
||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
}
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("print"));
|
||||
tracker.SendStr(msg);
|
||||
tracker.Close();
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
@ -113,14 +110,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
}
|
||||
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
utils::TCPSocket tracker;
|
||||
@ -134,7 +127,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
return tracker;
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string(cmd));
|
||||
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
|
||||
@ -45,6 +45,13 @@ class AllreduceBase : public IEngine {
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
return rank;
|
||||
@ -279,6 +286,11 @@ class AllreduceBase : public IEngine {
|
||||
return plinks.size();
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket ConnectTracker(void) const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
|
||||
@ -124,6 +124,13 @@ class IEngine {
|
||||
virtual int GetWorldSize(void) const = 0;
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const = 0;
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg) = 0;
|
||||
};
|
||||
|
||||
/*! \brief initialize the engine module */
|
||||
|
||||
90
src/engine_empty.cc
Normal file
90
src/engine_empty.cc
Normal file
@ -0,0 +1,90 @@
|
||||
/*!
|
||||
* \file engine_empty.cc
|
||||
* \brief this file provides a dummy implementation of engine that does nothing
|
||||
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||
* This is usually NOT needed, use engine_mpi or engine for real distributed version
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
|
||||
#include "./engine.h"
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
public:
|
||||
EmptyEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
return version_number;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
virtual int GetRank(void) const {
|
||||
return 0;
|
||||
}
|
||||
/*! \brief get total number of */
|
||||
virtual int GetWorldSize(void) const {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
return std::string("");
|
||||
}
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
EmptyEngine manager;
|
||||
|
||||
/*!
 * \brief initialize the synchronization module
 *  the body is empty in this implementation: both arguments are ignored
 * \param argc number of arguments (unused)
 * \param argv argument values (unused)
 */
|
||||
void Init(int argc, char *argv[]) {
|
||||
}
|
||||
/*! \brief finalize synchronization module; the body is empty in this implementation */
|
||||
void Finalize(void) {
|
||||
}
|
||||
|
||||
/*!
 * \brief singleton method to get engine
 * \return address of the file-scope EmptyEngine instance named manager
 */
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
@ -8,6 +8,7 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <cstdio>
|
||||
#include "./engine.h"
|
||||
#include "./utils.h"
|
||||
#include <mpi.h>
|
||||
@ -61,7 +62,12 @@ class MPIEngine : public IEngine {
|
||||
name[len] = '\0';
|
||||
return std::string(name);
|
||||
}
|
||||
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
// simply print information into the tracker
|
||||
if (GetRank() == 0) {
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
}
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#define RABIT_RABIT_INL_H
|
||||
// use engine for implementation
|
||||
#include "./engine.h"
|
||||
#include "./utils.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
@ -140,6 +141,21 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// print message to the tracker:
// forwards msg unchanged to the TrackerPrint method of the engine
// returned by engine::GetEngine()
|
||||
inline void TrackerPrint(const std::string &msg) {
|
||||
engine::GetEngine()->TrackerPrint(msg);
|
||||
}
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
// printf-style variant of TrackerPrint
// the text is formatted into a buffer of kPrintBuffer bytes, so a longer
// message is cut off at the buffer size
inline void TrackerPrintf(const char *fmt, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string msg(kPrintBuffer, '\0');
  va_list args;
  va_start(args, fmt);
  vsnprintf(&msg[0], kPrintBuffer, fmt, args);
  va_end(args);
  // msg still has kPrintBuffer characters at this point; cut it at the first
  // '\0' so that only the formatted text, not the padding, is passed on
  const std::string::size_type text_end = msg.find('\0');
  if (text_end != std::string::npos) msg.resize(text_end);
  TrackerPrint(msg);
}
|
||||
#endif
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
|
||||
17
src/rabit.h
17
src/rabit.h
@ -47,6 +47,23 @@ inline int GetRank(void);
|
||||
inline int GetWorldSize(void);
|
||||
/*! \brief get name of processor */
|
||||
inline std::string GetProcessorName(void);
|
||||
/*!
|
||||
* \brief print the msg to the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
inline void TrackerPrint(const std::string &msg);
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief print the msg to the tracker, this function may not be available
|
||||
* in very strict c++98 compilers, but is available most of the time
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param fmt the format string
|
||||
*/
|
||||
inline void TrackerPrintf(const char *fmt, ...);
|
||||
#endif
|
||||
/*!
|
||||
* \brief broadcast an memory region to all others from root
|
||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||
|
||||
@ -106,15 +106,6 @@ inline void Printf(const char *fmt, ...) {
|
||||
va_end(args);
|
||||
HandlePrint(msg.c_str());
|
||||
}
|
||||
/*! \brief printf, print message to the console */
|
||||
inline void LogPrintf(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
HandleLogPrint(msg.c_str());
|
||||
}
|
||||
/*! \brief portable version of snprintf */
|
||||
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
||||
va_list args;
|
||||
|
||||
@ -60,11 +60,11 @@ inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t s
|
||||
rabit::Allreduce<op::Sum>(&tsqr, 1);
|
||||
double tstd = sqrt(tsqr / nproc);
|
||||
if (rabit::GetRank() == 0) {
|
||||
utils::LogPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
|
||||
rabit::TrackerPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
|
||||
double ndata = n;
|
||||
ndata *= nrep * size;
|
||||
if (n != 0) {
|
||||
utils::LogPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
|
||||
rabit::TrackerPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -147,22 +147,22 @@ int main(int argc, char *argv[]) {
|
||||
if (iter == 0) {
|
||||
model.InitModel(n, 1.0f);
|
||||
local.InitModel(n, 1.0f + rank);
|
||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
} else {
|
||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
}
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, &local, ntrial, r);
|
||||
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
TestSum(&model, &local, ntrial, r);
|
||||
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model, &local);
|
||||
utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
}
|
||||
break;
|
||||
} catch (MockException &e) {
|
||||
|
||||
@ -136,22 +136,22 @@ int main(int argc, char *argv[]) {
|
||||
int iter = rabit::LoadCheckPoint(&model);
|
||||
if (iter == 0) {
|
||||
model.InitModel(n);
|
||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
} else {
|
||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
}
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, ntrial, r);
|
||||
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
TestSum(&model, ntrial, r);
|
||||
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model);
|
||||
utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
}
|
||||
break;
|
||||
} catch (MockException &e) {
|
||||
|
||||
@ -188,6 +188,11 @@ class Tracker:
|
||||
rnext = (r + 1) % nslave
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
def handle_print(self, slave, msg):
    """Write msg to stdout exactly as given; no newline is appended.

    slave -- entry of the sender (not used here)
    msg   -- text to write
    """
    out = sys.stdout
    out.write(msg)
|
||||
def log_print(self, msg):
    """Write msg followed by a newline to stderr."""
    line = msg + '\n'
    sys.stderr.write(line)
|
||||
|
||||
def accept_slaves(self, nslave):
|
||||
# set of nodes that have finished the job
|
||||
shutdown = {}
|
||||
@ -202,12 +207,16 @@ class Tracker:
|
||||
|
||||
while len(shutdown) != nslave:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = SlaveEntry(fd, s_addr)
|
||||
s = SlaveEntry(fd, s_addr)
|
||||
if s.cmd == 'print':
|
||||
msg = s.sock.recvstr()
|
||||
self.handle_print(s, msg)
|
||||
continue
|
||||
if s.cmd == 'shutdown':
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
assert s.rank not in wait_conn
|
||||
shutdown[s.rank] = s
|
||||
print 'Recieve %s signal from %d' % (s.cmd, s.rank)
|
||||
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank))
|
||||
continue
|
||||
assert s.cmd == 'start' or s.cmd == 'recover'
|
||||
# lazily initialize the slaves
|
||||
@ -233,21 +242,16 @@ class Tracker:
|
||||
job_map[s.jobid] = rank
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.cmd != 'start':
|
||||
print 'Recieve %s signal from %d' % (s.cmd, s.rank)
|
||||
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank))
|
||||
else:
|
||||
print 'Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank)
|
||||
self.log_print('Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank))
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
print 'All nodes finishes job'
|
||||
self.log_print('All nodes finishes job')
|
||||
|
||||
def mpi_submit(nslave, args):
|
||||
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
|
||||
print cmd
|
||||
return subprocess.check_call(cmd, shell = True)
|
||||
|
||||
def submit(nslave, args, fun_submit = mpi_submit):
|
||||
def submit(nslave, args, fun_submit):
|
||||
master = Tracker()
|
||||
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
||||
submit_thread.start()
|
||||
master.accept_slaves(nslaves)
|
||||
master.accept_slaves(nslave)
|
||||
submit_thread.join()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user