add tracker print

This commit is contained in:
tqchen
2014-12-19 18:40:06 -08:00
parent 6bf282c6c2
commit 6151899ce2
13 changed files with 210 additions and 59 deletions

View File

@@ -67,24 +67,21 @@ void AllreduceBase::Shutdown(void) {
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
int magic = kMagic;
// notify tracker rank i have shutdown
utils::TCPSocket tracker;
tracker.Create();
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect Tracker");
}
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
tracker.SendStr(task_id);
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
utils::TCPSocket::Finalize();
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
tracker.Close();
}
/*!
* \brief set parameters to the engine
* \param name parameter name
@@ -113,14 +110,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
}
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
}
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
@@ -134,7 +127,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
tracker.SendStr(task_id);
return tracker;
}
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string(cmd));
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors

View File

@@ -45,6 +45,13 @@ class AllreduceBase : public IEngine {
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
@@ -279,6 +286,11 @@ class AllreduceBase : public IEngine {
return plinks.size();
}
};
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up

View File

@@ -124,6 +124,13 @@ class IEngine {
virtual int GetWorldSize(void) const = 0;
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const = 0;
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg) = 0;
};
/*! \brief initialize the engine module */

90
src/engine_empty.cc Normal file
View File

@@ -0,0 +1,90 @@
/*!
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
* This is usually NOT needed, use engine_mpi or engine for real distributed version
* \author Tianqi Chen
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include "./engine.h"
namespace rabit {
namespace engine {
/*!
 * \brief EmptyEngine, an IEngine stub for single-node runs:
 *  it reports rank 0 and world size 1 and performs no communication
 */
class EmptyEngine : public IEngine {
 public:
  /*! \brief start with version number 0 */
  EmptyEngine(void) : version_number(0) {}
  /*!
   * \brief not supported by this engine, always reports an error through utils::Error;
   *  the message points callers to Allreduce_ instead
   */
  virtual void Allreduce(void *sendrecvbuf_,
                         size_t type_nbytes,
                         size_t count,
                         ReduceFunction reducer,
                         PreprocFunction prepare_fun,
                         void *prepare_arg) {
    utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
  }
  /*! \brief does nothing and leaves sendrecvbuf_ untouched */
  virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {}
  /*! \brief not supported, always reports an error through utils::Error */
  virtual void InitAfterException(void) {
    utils::Error("EmptyEngine is not fault tolerant");
  }
  /*!
   * \brief this engine loads nothing, neither model is touched
   * \return always 0
   */
  virtual int LoadCheckPoint(ISerializable *global_model,
                             ISerializable *local_model = NULL) {
    return 0;
  }
  /*! \brief nothing is saved, only the version number is advanced by one */
  virtual void CheckPoint(const ISerializable *global_model,
                          const ISerializable *local_model = NULL) {
    ++version_number;
  }
  /*! \brief get the number of CheckPoint calls made so far */
  virtual int VersionNumber(void) const {
    return version_number;
  }
  /*! \brief get rank of current node, always 0 */
  virtual int GetRank(void) const {
    return 0;
  }
  /*! \brief get the world size, always 1 */
  virtual int GetWorldSize(void) const {
    return 1;
  }
  /*! \brief get the host name of current node, always the empty string */
  virtual std::string GetHost(void) const {
    std::string host;
    return host;
  }
  /*! \brief print msg locally through utils::Printf */
  virtual void TrackerPrint(const std::string &msg) {
    const char *text = msg.c_str();
    utils::Printf("%s", text);
  }

 private:
  /*! \brief counter incremented by every CheckPoint call */
  int version_number;
};
// the single EmptyEngine instance of this file; GetEngine returns its address
EmptyEngine manager;
/*! \brief initialize the synchronization module; a no-op here, argc and argv are ignored */
void Init(int argc, char *argv[]) {
}
/*! \brief finalize the synchronization module; a no-op here */
void Finalize(void) {
}
/*!
 * \brief singleton accessor for the engine
 * \return pointer to the file-level EmptyEngine instance
 */
IEngine *GetEngine(void) {
  IEngine *engine_ptr = &manager;
  return engine_ptr;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
}
} // namespace engine
} // namespace rabit

View File

@@ -8,6 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <cstdio>
#include "./engine.h"
#include "./utils.h"
#include <mpi.h>
@@ -61,7 +62,12 @@ class MPIEngine : public IEngine {
name[len] = '\0';
return std::string(name);
}
/*!
 * \brief print msg through utils::Printf on the process whose rank is 0;
 *  every other rank prints nothing
 * \param msg message to be printed
 */
virtual void TrackerPrint(const std::string &msg) {
  if (GetRank() != 0) return;
  utils::Printf("%s", msg.c_str());
}
private:
int version_number;
};

View File

@@ -8,6 +8,7 @@
#define RABIT_RABIT_INL_H
// use engine for implementation
#include "./engine.h"
#include "./utils.h"
namespace rabit {
namespace engine {
@@ -140,6 +141,21 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
}
#endif // C++11
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);
}
#ifndef RABIT_STRICT_CXX98_
/*!
 * \brief printf-style variant of TrackerPrint
 *  The message is formatted into a buffer of kPrintBuffer bytes, so longer
 *  output is truncated to what fits in that buffer.
 * \param fmt the format string, followed by the values to format
 */
inline void TrackerPrintf(const char *fmt, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string msg(kPrintBuffer, '\0');
  va_list args;
  va_start(args, fmt);
  vsnprintf(&msg[0], kPrintBuffer, fmt, args);
  va_end(args);
  // The buffer was zero-filled, so the first '\0' marks the end of the
  // formatted text. Trim there so only the text is passed on, not the whole
  // kPrintBuffer-sized string with its trailing padding bytes.
  const std::string::size_type text_end = msg.find('\0');
  if (text_end != std::string::npos) {
    msg.resize(text_end);
  }
  TrackerPrint(msg);
}
#endif
// load latest check point
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {

View File

@@ -47,6 +47,23 @@ inline int GetRank(void);
inline int GetWorldSize(void);
/*! \brief get name of processor */
inline std::string GetProcessorName(void);
/*!
* \brief print the msg to the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg the message to be printed
*/
inline void TrackerPrint(const std::string &msg);
#ifndef RABIT_STRICT_CXX98_
/*!
* \brief print the msg to the tracker, this function may not be available
* in very strict c++98 compilers, but is available most of the time
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param fmt the format string
*/
inline void TrackerPrintf(const char *fmt, ...);
#endif
/*!
* \brief broadcast a memory region to all others from root
* Example: int a = 1; Broadcast(&a, sizeof(a), root);

View File

@@ -106,15 +106,6 @@ inline void Printf(const char *fmt, ...) {
va_end(args);
HandlePrint(msg.c_str());
}
/*!
 * \brief printf-style logging: format the arguments into a buffer of
 *  kPrintBuffer bytes and pass the resulting C string to HandleLogPrint
 * \param fmt the format string, followed by the values to format
 */
inline void LogPrintf(const char *fmt, ...) {
  va_list ap;
  std::string buffer(kPrintBuffer, '\0');
  va_start(ap, fmt);
  vsnprintf(&buffer[0], kPrintBuffer, fmt, ap);
  va_end(ap);
  HandleLogPrint(buffer.c_str());
}
/*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args;