add tracker print
This commit is contained in:
parent
6bf282c6c2
commit
6151899ce2
8
Makefile
8
Makefile
@ -7,8 +7,8 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
|
|||||||
BPATH=lib
|
BPATH=lib
|
||||||
# objectives that makes up rabit library
|
# objectives that makes up rabit library
|
||||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o
|
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o
|
||||||
ALIB= lib/librabit.a lib/librabit_mpi.a
|
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a
|
||||||
|
|
||||||
.PHONY: clean all
|
.PHONY: clean all
|
||||||
|
|
||||||
@ -18,8 +18,10 @@ $(BPATH)/allreduce_base.o: src/allreduce_base.cc src/*.h
|
|||||||
$(BPATH)/engine.o: src/engine.cc src/*.h
|
$(BPATH)/engine.o: src/engine.cc src/*.h
|
||||||
$(BPATH)/allreduce_robust.o: src/allreduce_robust.cc src/*.h
|
$(BPATH)/allreduce_robust.o: src/allreduce_robust.cc src/*.h
|
||||||
$(BPATH)/engine_mpi.o: src/engine_mpi.cc src/*.h
|
$(BPATH)/engine_mpi.o: src/engine_mpi.cc src/*.h
|
||||||
|
$(BPATH)/engine_empty.o: src/engine_empty.cc src/*.h
|
||||||
|
|
||||||
lib/librabit.a: $(OBJ)
|
lib/librabit.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o
|
||||||
|
lib/librabit_empty.a: $(BPATH)/engine_empty.o
|
||||||
lib/librabit_mpi.a: $(MPIOBJ)
|
lib/librabit_mpi.a: $(MPIOBJ)
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
|
|||||||
@ -67,24 +67,21 @@ void AllreduceBase::Shutdown(void) {
|
|||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
|
|
||||||
if (tracker_uri == "NULL") return;
|
if (tracker_uri == "NULL") return;
|
||||||
int magic = kMagic;
|
|
||||||
// notify tracker rank i have shutdown
|
// notify tracker rank i have shutdown
|
||||||
utils::TCPSocket tracker;
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
tracker.Create();
|
|
||||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
|
||||||
utils::Socket::Error("Connect Tracker");
|
|
||||||
}
|
|
||||||
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
|
|
||||||
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
|
||||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
|
||||||
|
|
||||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
|
||||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
|
||||||
tracker.SendStr(task_id);
|
|
||||||
tracker.SendStr(std::string("shutdown"));
|
tracker.SendStr(std::string("shutdown"));
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
utils::TCPSocket::Finalize();
|
utils::TCPSocket::Finalize();
|
||||||
}
|
}
|
||||||
|
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||||
|
if (tracker_uri == "NULL") {
|
||||||
|
utils::Printf("%s", msg.c_str()); return;
|
||||||
|
}
|
||||||
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
|
tracker.SendStr(std::string("print"));
|
||||||
|
tracker.SendStr(msg);
|
||||||
|
tracker.Close();
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
@ -113,14 +110,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief initialize connection to the tracker
|
||||||
* this function is also used when the engine start up
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||||
// single node mode
|
|
||||||
if (tracker_uri == "NULL") {
|
|
||||||
rank = 0; world_size = 1; return;
|
|
||||||
}
|
|
||||||
int magic = kMagic;
|
int magic = kMagic;
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
utils::TCPSocket tracker;
|
utils::TCPSocket tracker;
|
||||||
@ -134,7 +127,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
||||||
tracker.SendStr(task_id);
|
tracker.SendStr(task_id);
|
||||||
|
return tracker;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief connect to the tracker to fix the the missing links
|
||||||
|
* this function is also used when the engine start up
|
||||||
|
*/
|
||||||
|
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||||
|
// single node mode
|
||||||
|
if (tracker_uri == "NULL") {
|
||||||
|
rank = 0; world_size = 1; return;
|
||||||
|
}
|
||||||
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
tracker.SendStr(std::string(cmd));
|
tracker.SendStr(std::string(cmd));
|
||||||
|
|
||||||
// the rank of previous link, next link in ring
|
// the rank of previous link, next link in ring
|
||||||
int prev_rank, next_rank;
|
int prev_rank, next_rank;
|
||||||
// the rank of neighbors
|
// the rank of neighbors
|
||||||
|
|||||||
@ -45,6 +45,13 @@ class AllreduceBase : public IEngine {
|
|||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
virtual void SetParam(const char *name, const char *val);
|
virtual void SetParam(const char *name, const char *val);
|
||||||
|
/*!
|
||||||
|
* \brief print the msg in the tracker,
|
||||||
|
* this function can be used to communicate the information of the progress to
|
||||||
|
* the user who monitors the tracker
|
||||||
|
* \param msg message to be printed in the tracker
|
||||||
|
*/
|
||||||
|
virtual void TrackerPrint(const std::string &msg);
|
||||||
/*! \brief get rank */
|
/*! \brief get rank */
|
||||||
virtual int GetRank(void) const {
|
virtual int GetRank(void) const {
|
||||||
return rank;
|
return rank;
|
||||||
@ -279,6 +286,11 @@ class AllreduceBase : public IEngine {
|
|||||||
return plinks.size();
|
return plinks.size();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
/*!
|
||||||
|
* \brief initialize connection to the tracker
|
||||||
|
* \return a socket that initializes the connection
|
||||||
|
*/
|
||||||
|
utils::TCPSocket ConnectTracker(void) const;
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
|
|||||||
@ -124,6 +124,13 @@ class IEngine {
|
|||||||
virtual int GetWorldSize(void) const = 0;
|
virtual int GetWorldSize(void) const = 0;
|
||||||
/*! \brief get the host name of current node */
|
/*! \brief get the host name of current node */
|
||||||
virtual std::string GetHost(void) const = 0;
|
virtual std::string GetHost(void) const = 0;
|
||||||
|
/*!
|
||||||
|
* \brief print the msg in the tracker,
|
||||||
|
* this function can be used to communicate the information of the progress to
|
||||||
|
* the user who monitors the tracker
|
||||||
|
* \param msg message to be printed in the tracker
|
||||||
|
*/
|
||||||
|
virtual void TrackerPrint(const std::string &msg) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief intiialize the engine module */
|
/*! \brief intiialize the engine module */
|
||||||
|
|||||||
90
src/engine_empty.cc
Normal file
90
src/engine_empty.cc
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
/*!
|
||||||
|
* \file engine_empty.cc
|
||||||
|
* \brief this file provides a dummy implementation of engine that does nothing
|
||||||
|
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||||
|
* This is usually NOT needed, use engine_mpi or engine for real distributed version
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
|
#define NOMINMAX
|
||||||
|
|
||||||
|
#include "./engine.h"
|
||||||
|
namespace rabit {
|
||||||
|
namespace engine {
|
||||||
|
/*! \brief EmptyEngine */
|
||||||
|
class EmptyEngine : public IEngine {
|
||||||
|
public:
|
||||||
|
EmptyEngine(void) {
|
||||||
|
version_number = 0;
|
||||||
|
}
|
||||||
|
virtual void Allreduce(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer,
|
||||||
|
PreprocFunction prepare_fun,
|
||||||
|
void *prepare_arg) {
|
||||||
|
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||||
|
}
|
||||||
|
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||||
|
}
|
||||||
|
virtual void InitAfterException(void) {
|
||||||
|
utils::Error("EmptyEngine is not fault tolerant");
|
||||||
|
}
|
||||||
|
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||||
|
ISerializable *local_model = NULL) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
|
const ISerializable *local_model = NULL) {
|
||||||
|
version_number += 1;
|
||||||
|
}
|
||||||
|
virtual int VersionNumber(void) const {
|
||||||
|
return version_number;
|
||||||
|
}
|
||||||
|
/*! \brief get rank of current node */
|
||||||
|
virtual int GetRank(void) const {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
/*! \brief get total number of */
|
||||||
|
virtual int GetWorldSize(void) const {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
/*! \brief get the host name of current node */
|
||||||
|
virtual std::string GetHost(void) const {
|
||||||
|
return std::string("");
|
||||||
|
}
|
||||||
|
virtual void TrackerPrint(const std::string &msg) {
|
||||||
|
// simply print information into the tracker
|
||||||
|
utils::Printf("%s", msg.c_str());
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
int version_number;
|
||||||
|
};
|
||||||
|
|
||||||
|
// singleton sync manager
|
||||||
|
EmptyEngine manager;
|
||||||
|
|
||||||
|
/*!
 * \brief initialize the engine module
 *  There is nothing to set up in this engine, so both arguments are ignored.
 * \param argc number of entries in argv
 * \param argv the arguments, left untouched
 */
void Init(int argc, char *argv[]) {
  (void)argc;
  (void)argv;
}
|
||||||
|
/*!
 * \brief finalize the engine module
 *  Nothing was set up by Init, so there is nothing to release here.
 */
void Finalize(void) {
  return;
}
|
||||||
|
|
||||||
|
/*! \brief singleton method to get engine */
|
||||||
|
IEngine *GetEngine(void) {
|
||||||
|
return &manager;
|
||||||
|
}
|
||||||
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
|
void Allreduce_(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
IEngine::ReduceFunction red,
|
||||||
|
mpi::DataType dtype,
|
||||||
|
mpi::OpType op,
|
||||||
|
IEngine::PreprocFunction prepare_fun,
|
||||||
|
void *prepare_arg) {
|
||||||
|
}
|
||||||
|
} // namespace engine
|
||||||
|
} // namespace rabit
|
||||||
@ -8,6 +8,7 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
|
#include <cstdio>
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
@ -61,7 +62,12 @@ class MPIEngine : public IEngine {
|
|||||||
name[len] = '\0';
|
name[len] = '\0';
|
||||||
return std::string(name);
|
return std::string(name);
|
||||||
}
|
}
|
||||||
|
virtual void TrackerPrint(const std::string &msg) {
|
||||||
|
// simply print information into the tracker
|
||||||
|
if (GetRank() == 0) {
|
||||||
|
utils::Printf("%s", msg.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
private:
|
private:
|
||||||
int version_number;
|
int version_number;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#define RABIT_RABIT_INL_H
|
#define RABIT_RABIT_INL_H
|
||||||
// use engine for implementation
|
// use engine for implementation
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
|
#include "./utils.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -140,6 +141,21 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
|||||||
}
|
}
|
||||||
#endif // C++11
|
#endif // C++11
|
||||||
|
|
||||||
|
// print message to the tracker
|
||||||
|
inline void TrackerPrint(const std::string &msg) {
|
||||||
|
engine::GetEngine()->TrackerPrint(msg);
|
||||||
|
}
|
||||||
|
#ifndef RABIT_STRICT_CXX98_
|
||||||
|
inline void TrackerPrintf(const char *fmt, ...) {
|
||||||
|
const int kPrintBuffer = 1 << 10;
|
||||||
|
std::string msg(kPrintBuffer, '\0');
|
||||||
|
va_list args;
|
||||||
|
va_start(args, fmt);
|
||||||
|
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||||
|
va_end(args);
|
||||||
|
TrackerPrint(msg);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
// load latest check point
|
// load latest check point
|
||||||
inline int LoadCheckPoint(ISerializable *global_model,
|
inline int LoadCheckPoint(ISerializable *global_model,
|
||||||
ISerializable *local_model) {
|
ISerializable *local_model) {
|
||||||
|
|||||||
17
src/rabit.h
17
src/rabit.h
@ -47,6 +47,23 @@ inline int GetRank(void);
|
|||||||
inline int GetWorldSize(void);
|
inline int GetWorldSize(void);
|
||||||
/*! \brief get name of processor */
|
/*! \brief get name of processor */
|
||||||
inline std::string GetProcessorName(void);
|
inline std::string GetProcessorName(void);
|
||||||
|
/*!
|
||||||
|
* \brief print the msg to the tracker,
|
||||||
|
* this function can be used to communicate the information of the progress to
|
||||||
|
* the user who monitors the tracker
|
||||||
|
* \param msg the message to be printed
|
||||||
|
*/
|
||||||
|
inline void TrackerPrint(const std::string &msg);
|
||||||
|
#ifndef RABIT_STRICT_CXX98_
|
||||||
|
/*!
|
||||||
|
* \brief print the msg to the tracker, this function may not be available
|
||||||
|
* in very strict c++98 compilers, but is available most of the time
|
||||||
|
* this function can be used to communicate the information of the progress to
|
||||||
|
* the user who monitors the tracker
|
||||||
|
* \param fmt the format string
|
||||||
|
*/
|
||||||
|
inline void TrackerPrintf(const char *fmt, ...);
|
||||||
|
#endif
|
||||||
/*!
|
/*!
|
||||||
* \brief broadcast an memory region to all others from root
|
* \brief broadcast an memory region to all others from root
|
||||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||||
|
|||||||
@ -106,15 +106,6 @@ inline void Printf(const char *fmt, ...) {
|
|||||||
va_end(args);
|
va_end(args);
|
||||||
HandlePrint(msg.c_str());
|
HandlePrint(msg.c_str());
|
||||||
}
|
}
|
||||||
/*! \brief printf, print message to the console */
|
|
||||||
inline void LogPrintf(const char *fmt, ...) {
|
|
||||||
std::string msg(kPrintBuffer, '\0');
|
|
||||||
va_list args;
|
|
||||||
va_start(args, fmt);
|
|
||||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
|
||||||
va_end(args);
|
|
||||||
HandleLogPrint(msg.c_str());
|
|
||||||
}
|
|
||||||
/*! \brief portable version of snprintf */
|
/*! \brief portable version of snprintf */
|
||||||
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
||||||
va_list args;
|
va_list args;
|
||||||
|
|||||||
@ -60,11 +60,11 @@ inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t s
|
|||||||
rabit::Allreduce<op::Sum>(&tsqr, 1);
|
rabit::Allreduce<op::Sum>(&tsqr, 1);
|
||||||
double tstd = sqrt(tsqr / nproc);
|
double tstd = sqrt(tsqr / nproc);
|
||||||
if (rabit::GetRank() == 0) {
|
if (rabit::GetRank() == 0) {
|
||||||
utils::LogPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
|
rabit::TrackerPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
|
||||||
double ndata = n;
|
double ndata = n;
|
||||||
ndata *= nrep * size;
|
ndata *= nrep * size;
|
||||||
if (n != 0) {
|
if (n != 0) {
|
||||||
utils::LogPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
|
rabit::TrackerPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -147,22 +147,22 @@ int main(int argc, char *argv[]) {
|
|||||||
if (iter == 0) {
|
if (iter == 0) {
|
||||||
model.InitModel(n, 1.0f);
|
model.InitModel(n, 1.0f);
|
||||||
local.InitModel(n, 1.0f + rank);
|
local.InitModel(n, 1.0f + rank);
|
||||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||||
} else {
|
} else {
|
||||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||||
}
|
}
|
||||||
for (int r = iter; r < 3; ++r) {
|
for (int r = iter; r < 3; ++r) {
|
||||||
TestMax(&model, &local, ntrial, r);
|
TestMax(&model, &local, ntrial, r);
|
||||||
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||||
int step = std::max(nproc / 3, 1);
|
int step = std::max(nproc / 3, 1);
|
||||||
for (int i = 0; i < nproc; i += step) {
|
for (int i = 0; i < nproc; i += step) {
|
||||||
TestBcast(n, i, ntrial, r);
|
TestBcast(n, i, ntrial, r);
|
||||||
}
|
}
|
||||||
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||||
TestSum(&model, &local, ntrial, r);
|
TestSum(&model, &local, ntrial, r);
|
||||||
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||||
rabit::CheckPoint(&model, &local);
|
rabit::CheckPoint(&model, &local);
|
||||||
utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
} catch (MockException &e) {
|
} catch (MockException &e) {
|
||||||
|
|||||||
@ -136,22 +136,22 @@ int main(int argc, char *argv[]) {
|
|||||||
int iter = rabit::LoadCheckPoint(&model);
|
int iter = rabit::LoadCheckPoint(&model);
|
||||||
if (iter == 0) {
|
if (iter == 0) {
|
||||||
model.InitModel(n);
|
model.InitModel(n);
|
||||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||||
} else {
|
} else {
|
||||||
utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||||
}
|
}
|
||||||
for (int r = iter; r < 3; ++r) {
|
for (int r = iter; r < 3; ++r) {
|
||||||
TestMax(&model, ntrial, r);
|
TestMax(&model, ntrial, r);
|
||||||
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||||
int step = std::max(nproc / 3, 1);
|
int step = std::max(nproc / 3, 1);
|
||||||
for (int i = 0; i < nproc; i += step) {
|
for (int i = 0; i < nproc; i += step) {
|
||||||
TestBcast(n, i, ntrial, r);
|
TestBcast(n, i, ntrial, r);
|
||||||
}
|
}
|
||||||
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||||
TestSum(&model, ntrial, r);
|
TestSum(&model, ntrial, r);
|
||||||
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||||
rabit::CheckPoint(&model);
|
rabit::CheckPoint(&model);
|
||||||
utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
} catch (MockException &e) {
|
} catch (MockException &e) {
|
||||||
|
|||||||
@ -188,6 +188,11 @@ class Tracker:
|
|||||||
rnext = (r + 1) % nslave
|
rnext = (r + 1) % nslave
|
||||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||||
return ring_map
|
return ring_map
|
||||||
|
def handle_print(self, slave, msg):
    # write the message to stdout exactly as given: no newline is
    # appended; the slave argument is not used here
    out = sys.stdout
    out.write(msg)
|
||||||
|
def log_print(self, msg):
    # write the message to stderr, terminated by a newline
    line = msg + '\n'
    sys.stderr.write(line)
|
||||||
|
|
||||||
def accept_slaves(self, nslave):
|
def accept_slaves(self, nslave):
|
||||||
# set of nodes that finishs the job
|
# set of nodes that finishs the job
|
||||||
shutdown = {}
|
shutdown = {}
|
||||||
@ -202,12 +207,16 @@ class Tracker:
|
|||||||
|
|
||||||
while len(shutdown) != nslave:
|
while len(shutdown) != nslave:
|
||||||
fd, s_addr = self.sock.accept()
|
fd, s_addr = self.sock.accept()
|
||||||
s = SlaveEntry(fd, s_addr)
|
s = SlaveEntry(fd, s_addr)
|
||||||
|
if s.cmd == 'print':
|
||||||
|
msg = s.sock.recvstr()
|
||||||
|
self.handle_print(s, msg)
|
||||||
|
continue
|
||||||
if s.cmd == 'shutdown':
|
if s.cmd == 'shutdown':
|
||||||
assert s.rank >= 0 and s.rank not in shutdown
|
assert s.rank >= 0 and s.rank not in shutdown
|
||||||
assert s.rank not in wait_conn
|
assert s.rank not in wait_conn
|
||||||
shutdown[s.rank] = s
|
shutdown[s.rank] = s
|
||||||
print 'Recieve %s signal from %d' % (s.cmd, s.rank)
|
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank))
|
||||||
continue
|
continue
|
||||||
assert s.cmd == 'start' or s.cmd == 'recover'
|
assert s.cmd == 'start' or s.cmd == 'recover'
|
||||||
# lazily initialize the slaves
|
# lazily initialize the slaves
|
||||||
@ -233,21 +242,16 @@ class Tracker:
|
|||||||
job_map[s.jobid] = rank
|
job_map[s.jobid] = rank
|
||||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||||
if s.cmd != 'start':
|
if s.cmd != 'start':
|
||||||
print 'Recieve %s signal from %d' % (s.cmd, s.rank)
|
self.log_print('Recieve %s signal from %d' % (s.cmd, s.rank))
|
||||||
else:
|
else:
|
||||||
print 'Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank)
|
self.log_print('Recieve %s signal from %s assign rank %d' % (s.cmd, s.host, s.rank))
|
||||||
if s.wait_accept > 0:
|
if s.wait_accept > 0:
|
||||||
wait_conn[rank] = s
|
wait_conn[rank] = s
|
||||||
print 'All nodes finishes job'
|
self.log_print('All nodes finishes job')
|
||||||
|
|
||||||
def mpi_submit(nslave, args):
|
def submit(nslave, args, fun_submit):
|
||||||
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
|
|
||||||
print cmd
|
|
||||||
return subprocess.check_call(cmd, shell = True)
|
|
||||||
|
|
||||||
def submit(nslave, args, fun_submit = mpi_submit):
|
|
||||||
master = Tracker()
|
master = Tracker()
|
||||||
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
||||||
submit_thread.start()
|
submit_thread.start()
|
||||||
master.accept_slaves(nslaves)
|
master.accept_slaves(nslave)
|
||||||
submit_thread.join()
|
submit_thread.join()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user