ok
This commit is contained in:
parent
e7a22792ac
commit
ab278513ab
@ -7,6 +7,7 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include "./allreduce_base.h"
|
||||
|
||||
@ -21,13 +22,26 @@ AllreduceBase::AllreduceBase(void) {
|
||||
nport_trial = 1000;
|
||||
rank = -1;
|
||||
world_size = 1;
|
||||
hadoop_mode = 0;
|
||||
version_number = 0;
|
||||
job_id = "NULL";
|
||||
task_id = "NULL";
|
||||
this->SetParam("reduce_buffer", "256MB");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
void AllreduceBase::Init(void) {
|
||||
{
|
||||
// handling for hadoop
|
||||
const char *task_id = getenv("mapred_task_id");
|
||||
if (hadoop_mode != 0) {
|
||||
utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id");
|
||||
}
|
||||
if (task_id != NULL) {
|
||||
this->SetParam("task_id", task_id);
|
||||
this->SetParam("hadoop_mode", "1");
|
||||
}
|
||||
}
|
||||
// start socket
|
||||
utils::Socket::Startup();
|
||||
utils::Assert(links.size() == 0, "can only call Init once");
|
||||
this->host_uri = utils::SockAddr::GetHostName();
|
||||
@ -54,7 +68,7 @@ void AllreduceBase::Shutdown(void) {
|
||||
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
||||
|
||||
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
master.SendStr(job_id);
|
||||
master.SendStr(task_id);
|
||||
master.SendStr(std::string("shutdown"));
|
||||
master.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
@ -67,7 +81,8 @@ void AllreduceBase::Shutdown(void) {
|
||||
void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||
if (!strcmp(name, "job_id")) job_id = val;
|
||||
if (!strcmp(name, "task_id")) task_id = val;
|
||||
if (!strcmp(name, "hadoop_mode")) hadoop_mode = atoi(val);
|
||||
if (!strcmp(name, "reduce_buffer")) {
|
||||
char unit;
|
||||
unsigned long amount;
|
||||
@ -104,7 +119,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
||||
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
||||
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
master.SendStr(job_id);
|
||||
master.SendStr(task_id);
|
||||
master.SendStr(std::string(cmd));
|
||||
{// get new ranks
|
||||
int newrank;
|
||||
|
||||
@ -285,6 +285,8 @@ class AllreduceBase : public IEngine {
|
||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||
//---- data structure related to model ----
|
||||
int version_number;
|
||||
// whether the job is running in hadoop
|
||||
int hadoop_mode;
|
||||
//---- local data related to link ----
|
||||
// index of parent link, can be -1, meaning this is root of the tree
|
||||
int parent_index;
|
||||
@ -297,7 +299,7 @@ class AllreduceBase : public IEngine {
|
||||
//----- meta information-----
|
||||
// unique identifier of the possible job this process is doing
|
||||
// used to assign ranks, optional, default to NULL
|
||||
std::string job_id;
|
||||
std::string task_id;
|
||||
// uri of current host, to be set by Init
|
||||
std::string host_uri;
|
||||
// uri of master
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user