add tracker print
This commit is contained in:
@@ -67,24 +67,21 @@ void AllreduceBase::Shutdown(void) {
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
int magic = kMagic;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker;
|
||||
tracker.Create();
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
utils::Socket::Error("Connect Tracker");
|
||||
}
|
||||
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
|
||||
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||
|
||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
}
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("print"));
|
||||
tracker.SendStr(msg);
|
||||
tracker.Close();
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
@@ -113,14 +110,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
}
|
||||
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
utils::TCPSocket tracker;
|
||||
@@ -134,7 +127,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
return tracker;
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string(cmd));
|
||||
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
|
||||
@@ -45,6 +45,13 @@ class AllreduceBase : public IEngine {
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
return rank;
|
||||
@@ -279,6 +286,11 @@ class AllreduceBase : public IEngine {
|
||||
return plinks.size();
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket ConnectTracker(void) const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
|
||||
@@ -124,6 +124,13 @@ class IEngine {
|
||||
virtual int GetWorldSize(void) const = 0;
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const = 0;
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg) = 0;
|
||||
};
|
||||
|
||||
/*! \brief intiialize the engine module */
|
||||
|
||||
90
src/engine_empty.cc
Normal file
90
src/engine_empty.cc
Normal file
@@ -0,0 +1,90 @@
|
||||
/*!
|
||||
* \file engine_empty.cc
|
||||
* \brief this file provides a dummy implementation of engine that does nothing
|
||||
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||
* This is usually NOT needed, use engine_mpi or engine for real distributed version
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
|
||||
#include "./engine.h"
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
public:
|
||||
EmptyEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
return version_number;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
virtual int GetRank(void) const {
|
||||
return 0;
|
||||
}
|
||||
/*! \brief get total number of */
|
||||
virtual int GetWorldSize(void) const {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
return std::string("");
|
||||
}
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
EmptyEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
@@ -8,6 +8,7 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <cstdio>
|
||||
#include "./engine.h"
|
||||
#include "./utils.h"
|
||||
#include <mpi.h>
|
||||
@@ -61,7 +62,12 @@ class MPIEngine : public IEngine {
|
||||
name[len] = '\0';
|
||||
return std::string(name);
|
||||
}
|
||||
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
// simply print information into the tracker
|
||||
if (GetRank() == 0) {
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
}
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#define RABIT_RABIT_INL_H
|
||||
// use engine for implementation
|
||||
#include "./engine.h"
|
||||
#include "./utils.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
@@ -140,6 +141,21 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// print message to the tracker
|
||||
inline void TrackerPrint(const std::string &msg) {
|
||||
engine::GetEngine()->TrackerPrint(msg);
|
||||
}
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
inline void TrackerPrintf(const char *fmt, ...) {
|
||||
const int kPrintBuffer = 1 << 10;
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
TrackerPrint(msg);
|
||||
}
|
||||
#endif
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
|
||||
17
src/rabit.h
17
src/rabit.h
@@ -47,6 +47,23 @@ inline int GetRank(void);
|
||||
inline int GetWorldSize(void);
|
||||
/*! \brief get name of processor */
|
||||
inline std::string GetProcessorName(void);
|
||||
/*!
|
||||
* \brief print the msg to the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg, the message to be printed
|
||||
*/
|
||||
inline void TrackerPrint(const std::string &msg);
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief print the msg to the tracker, this function may not be available
|
||||
* in very strict c++98 compilers, but is available most of the time
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param fmt the format string
|
||||
*/
|
||||
inline void TrackerPrintf(const char *fmt, ...);
|
||||
#endif
|
||||
/*!
|
||||
* \brief broadcast an memory region to all others from root
|
||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||
|
||||
@@ -106,15 +106,6 @@ inline void Printf(const char *fmt, ...) {
|
||||
va_end(args);
|
||||
HandlePrint(msg.c_str());
|
||||
}
|
||||
/*! \brief printf, print message to the console */
|
||||
inline void LogPrintf(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
HandleLogPrint(msg.c_str());
|
||||
}
|
||||
/*! \brief portable version of snprintf */
|
||||
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
||||
va_list args;
|
||||
|
||||
Reference in New Issue
Block a user