Make rabit library thread local
This commit is contained in:
parent
aeb4008606
commit
be50e7b632
@ -2,25 +2,25 @@ export CC = gcc
|
||||
export CXX = g++
|
||||
export MPICXX = mpicxx
|
||||
export LDFLAGS= -pthread -lm -L../lib
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -fopenmp -I../include
|
||||
|
||||
.PHONY: clean all lib libmpi
|
||||
BIN = basic.rabit broadcast.rabit
|
||||
MOCKBIN= lazy_allreduce.mock
|
||||
|
||||
all: $(BIN)
|
||||
basic.rabit: basic.cc lib
|
||||
broadcast.rabit: broadcast.cc lib
|
||||
lazy_allreduce.mock: lazy_allreduce.cc lib
|
||||
basic.rabit: basic.cc lib ../lib/librabit.a
|
||||
broadcast.rabit: broadcast.cc lib ../lib/librabit.a
|
||||
lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a
|
||||
|
||||
$(BIN) :
|
||||
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit
|
||||
$(BIN) :
|
||||
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
|
||||
|
||||
$(MOCKBIN) :
|
||||
$(MOCKBIN) :
|
||||
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
|
||||
|
||||
$(OBJ) :
|
||||
$(OBJ) :
|
||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||
|
||||
clean:
|
||||
$(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~
|
||||
$(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#include <vector>
|
||||
#include <rabit.h>
|
||||
#include <rabit/rabit.h>
|
||||
using namespace rabit;
|
||||
int main(int argc, char *argv[]) {
|
||||
int N = 3;
|
||||
@ -19,7 +19,7 @@ int main(int argc, char *argv[]) {
|
||||
rabit::Init(argc, argv);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
a[i] = rabit::GetRank() + i;
|
||||
}
|
||||
}
|
||||
printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
// allreduce takes the max of each element across all processes
|
||||
@ -29,7 +29,7 @@ int main(int argc, char *argv[]) {
|
||||
// second allreduce that sums everything up
|
||||
Allreduce<op::Sum>(&a[0], N);
|
||||
printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#include <rabit.h>
|
||||
#include <rabit/rabit.h>
|
||||
using namespace rabit;
|
||||
const int N = 3;
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
@ -5,7 +5,8 @@
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <rabit.h>
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
using namespace rabit;
|
||||
const int N = 3;
|
||||
int main(int argc, char *argv[]) {
|
||||
@ -16,18 +17,18 @@ int main(int argc, char *argv[]) {
|
||||
printf("@node[%d] run prepare function\n", rabit::GetRank());
|
||||
for (int i = 0; i < N; ++i) {
|
||||
a[i] = rabit::GetRank() + i;
|
||||
}
|
||||
}
|
||||
};
|
||||
printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
// allreduce takes the max of each element across all processes
|
||||
Allreduce<op::Max>(&a[0], N, prepare);
|
||||
Allreduce<op::Max>(&a[0], N, prepare);
|
||||
printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
// run second allreduce
|
||||
Allreduce<op::Sum>(&a[0], N);
|
||||
printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
rabit::GetRank(), a[0], a[1], a[2]);
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -51,7 +51,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
}
|
||||
|
||||
// initialization function
|
||||
void AllreduceBase::Init(void) {
|
||||
void AllreduceBase::Init(int argc, char* argv[]) {
|
||||
// setup from environment variables
|
||||
// handler to get variables from env
|
||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
||||
@ -60,6 +60,14 @@ void AllreduceBase::Init(void) {
|
||||
this->SetParam(env_vars[i].c_str(), value);
|
||||
}
|
||||
}
|
||||
// arguments passed in override the environment variables.
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
this->SetParam(name, val);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// handling for hadoop
|
||||
const char *task_id = getenv("mapred_tip_id");
|
||||
|
||||
@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(void);
|
||||
virtual void Init(int argc, char* argv[]);
|
||||
// shutdown the engine
|
||||
virtual void Shutdown(void);
|
||||
/*!
|
||||
|
||||
@ -31,8 +31,8 @@ AllreduceRobust::AllreduceRobust(void) {
|
||||
env_vars.push_back("rabit_global_replica");
|
||||
env_vars.push_back("rabit_local_replica");
|
||||
}
|
||||
void AllreduceRobust::Init(void) {
|
||||
AllreduceBase::Init();
|
||||
void AllreduceRobust::Init(int argc, char* argv[]) {
|
||||
AllreduceBase::Init(argc, argv);
|
||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||
}
|
||||
/*! \brief shutdown the engine */
|
||||
|
||||
@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
AllreduceRobust(void);
|
||||
virtual ~AllreduceRobust(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(void);
|
||||
virtual void Init(int argc, char* argv[]);
|
||||
/*! \brief shutdown the engine */
|
||||
virtual void Shutdown(void);
|
||||
/*!
|
||||
|
||||
@ -10,42 +10,72 @@
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
|
||||
#include <memory>
|
||||
#include "../include/rabit/internal/engine.h"
|
||||
#include "./allreduce_base.h"
|
||||
#include "./allreduce_robust.h"
|
||||
#include "./thread_local.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// singleton sync manager
|
||||
#ifndef RABIT_USE_BASE
|
||||
#ifndef RABIT_USE_MOCK
|
||||
AllreduceRobust manager;
|
||||
typedef AllreduceRobust Manager;
|
||||
#else
|
||||
AllreduceMock manager;
|
||||
typedef AllreduceMock Manager;
|
||||
#endif
|
||||
#else
|
||||
AllreduceBase manager;
|
||||
typedef AllreduceBase Manager;
|
||||
#endif
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct ThreadLocalEntry {
|
||||
/*! \brief stores the current engine */
|
||||
std::unique_ptr<Manager> engine;
|
||||
/*! \brief whether init has been called */
|
||||
bool initialized;
|
||||
/*! \brief constructor */
|
||||
ThreadLocalEntry() : initialized(false) {}
|
||||
};
|
||||
|
||||
// define the threadlocal store.
|
||||
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
||||
|
||||
/*! \brief initialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
manager.SetParam(name, val);
|
||||
}
|
||||
}
|
||||
manager.Init();
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
utils::Check(e->engine.get() == nullptr,
|
||||
"rabit::Init is already called in this thread");
|
||||
e->initialized = true;
|
||||
e->engine.reset(new Manager());
|
||||
e->engine->Init(argc, argv);
|
||||
}
|
||||
|
||||
/*! \brief finalize synchronization module */
|
||||
void Finalize(void) {
|
||||
manager.Shutdown();
|
||||
void Finalize() {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
utils::Check(e->engine.get() != nullptr,
|
||||
"rabit::Finalize engine is not initialized or already been finalized.");
|
||||
e->engine->Shutdown();
|
||||
e->engine.reset(nullptr);
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
IEngine *GetEngine() {
|
||||
// un-initialized default manager.
|
||||
static AllreduceBase default_manager;
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
IEngine* ptr = e->engine.get();
|
||||
if (ptr == nullptr) {
|
||||
utils::Check(!e->initialized,
|
||||
"Doing rabit call after Finalize");
|
||||
return &default_manager;
|
||||
} else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
@ -63,15 +93,18 @@ void Allreduce_(void *sendrecvbuf,
|
||||
// Construct an empty reduce handle: handle_, redfunc_ and htype_ all start
// out as NULL. redfunc_ is filled in later by ReduceHandle::Init.
ReduceHandle::ReduceHandle(void)
|
||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||
}
|
||||
|
||||
// Destructor; the body is empty, so nothing is released here.
ReduceHandle::~ReduceHandle(void) {}
|
||||
|
||||
// Return the size stored in dtype.type_size, converted to int.
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return static_cast<int>(dtype.type_size);
|
||||
}
|
||||
|
||||
// Bind the reduction function to this handle by storing it in redfunc_.
// The handle must not have been initialized before. type_nbytes is not
// used by this implementation.
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
// a handle may only be bound once: redfunc_ must still be NULL here
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
|
||||
redfunc_ = redfunc;
|
||||
}
|
||||
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
|
||||
87
src/thread_local.h
Normal file
87
src/thread_local.h
Normal file
@ -0,0 +1,87 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* \file thread_local.h
|
||||
* \brief Common utility for thread local storage.
|
||||
*/
|
||||
#ifndef RABIT_THREAD_LOCAL_H_
|
||||
#define RABIT_THREAD_LOCAL_H_
|
||||
|
||||
#include "../include/dmlc/base.h"
|
||||
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
#include <mutex>
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace rabit {
|
||||
|
||||
// Macro handling for thread-local variables: pick the thread-local storage
// specifier this compiler understands and expose it as MX_TREAD_LOCAL.
#if defined(__GNUC__)
#define MX_TREAD_LOCAL __thread
#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#elif defined(__cplusplus) && __cplusplus >= 201103L
// standard C++11 keyword, used when no compiler-specific spelling matched
#define MX_TREAD_LOCAL thread_local
#endif

#ifndef MX_TREAD_LOCAL
// "#message" is not a preprocessing directive; "#pragma message" is the
// spelling GCC, Clang and MSVC accept for emitting a build-time note.
#pragma message("Warning: Threadlocal is not enabled")
#endif
|
||||
|
||||
/*!
|
||||
* \brief A threadlocal store to store threadlocal variables.
|
||||
* Will return a thread local singleton of type T
|
||||
* \tparam T the type we like to store
|
||||
*/
|
||||
template<typename T>
|
||||
class ThreadLocalStore {
|
||||
public:
|
||||
/*! \return get a thread local singleton */
|
||||
static T* Get() {
|
||||
static MX_TREAD_LOCAL T* ptr = nullptr;
|
||||
if (ptr == nullptr) {
|
||||
ptr = new T();
|
||||
Singleton()->RegisterDelete(ptr);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief constructor */
|
||||
ThreadLocalStore() {}
|
||||
/*! \brief destructor */
|
||||
~ThreadLocalStore() {
|
||||
for (size_t i = 0; i < data_.size(); ++i) {
|
||||
delete data_[i];
|
||||
}
|
||||
}
|
||||
/*! \return singleton of the store */
|
||||
static ThreadLocalStore<T> *Singleton() {
|
||||
static ThreadLocalStore<T> inst;
|
||||
return &inst;
|
||||
}
|
||||
/*!
|
||||
* \brief register str for internal deletion
|
||||
* \param str the string pointer
|
||||
*/
|
||||
void RegisterDelete(T *str) {
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
data_.push_back(str);
|
||||
lock.unlock();
|
||||
#else
|
||||
data_.push_back(str);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
/*! \brief internal mutex */
|
||||
std::mutex mutex_;
|
||||
#endif
|
||||
/*!\brief internal data */
|
||||
std::vector<T*> data_;
|
||||
};
|
||||
} // namespace rabit
|
||||
#endif // RABIT_THREAD_LOCAL_H_
|
||||
Loading…
x
Reference in New Issue
Block a user