Make rabit library thread local

This commit is contained in:
tqchen 2016-03-01 20:12:51 -08:00
parent aeb4008606
commit be50e7b632
10 changed files with 166 additions and 37 deletions

View File

@ -2,25 +2,25 @@ export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../lib
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -fopenmp -I../include
.PHONY: clean all lib libmpi
BIN = basic.rabit broadcast.rabit
MOCKBIN= lazy_allreduce.mock
all: $(BIN)
basic.rabit: basic.cc lib
broadcast.rabit: broadcast.cc lib
lazy_allreduce.mock: lazy_allreduce.cc lib
basic.rabit: basic.cc lib ../lib/librabit.a
broadcast.rabit: broadcast.cc lib ../lib/librabit.a
lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a
$(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit
$(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
$(MOCKBIN) :
$(MOCKBIN) :
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
$(OBJ) :
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
clean:
$(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~
$(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~

View File

@ -8,7 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#include <vector>
#include <rabit.h>
#include <rabit/rabit.h>
using namespace rabit;
int main(int argc, char *argv[]) {
int N = 3;
@ -19,7 +19,7 @@ int main(int argc, char *argv[]) {
rabit::Init(argc, argv);
for (int i = 0; i < N; ++i) {
a[i] = rabit::GetRank() + i;
}
}
printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
// allreduce take max of each elements in all processes
@ -29,7 +29,7 @@ int main(int argc, char *argv[]) {
// second allreduce that sums everything up
Allreduce<op::Sum>(&a[0], N);
printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
rabit::GetRank(), a[0], a[1], a[2]);
rabit::Finalize();
return 0;
}

View File

@ -1,4 +1,4 @@
#include <rabit.h>
#include <rabit/rabit.h>
using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {

View File

@ -5,7 +5,8 @@
*
* \author Tianqi Chen
*/
#include <rabit.h>
#include <rabit/rabit.h>
using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
@ -16,18 +17,18 @@ int main(int argc, char *argv[]) {
printf("@node[%d] run prepare function\n", rabit::GetRank());
for (int i = 0; i < N; ++i) {
a[i] = rabit::GetRank() + i;
}
}
};
printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
// allreduce take max of each elements in all processes
Allreduce<op::Max>(&a[0], N, prepare);
Allreduce<op::Max>(&a[0], N, prepare);
printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
rabit::GetRank(), a[0], a[1], a[2]);
// run second allreduce
Allreduce<op::Sum>(&a[0], N);
printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
rabit::GetRank(), a[0], a[1], a[2]);
rabit::Finalize();
return 0;
}

View File

@ -51,7 +51,7 @@ AllreduceBase::AllreduceBase(void) {
}
// initialization function
void AllreduceBase::Init(void) {
void AllreduceBase::Init(int argc, char* argv[]) {
// setup from environment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
@ -60,6 +60,14 @@ void AllreduceBase::Init(void) {
this->SetParam(env_vars[i].c_str(), value);
}
}
// pass in arguments override env variable.
for (int i = 0; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
this->SetParam(name, val);
}
}
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");

View File

@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual void Init(void);
virtual void Init(int argc, char* argv[]);
// shutdown the engine
virtual void Shutdown(void);
/*!

View File

@ -31,8 +31,8 @@ AllreduceRobust::AllreduceRobust(void) {
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
void AllreduceRobust::Init(void) {
AllreduceBase::Init();
void AllreduceRobust::Init(int argc, char* argv[]) {
AllreduceBase::Init(argc, argv);
result_buffer_round = std::max(world_size / num_global_replica, 1);
}
/*! \brief shutdown the engine */

View File

@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase {
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
// initialize the manager
virtual void Init(void);
virtual void Init(int argc, char* argv[]);
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*!

View File

@ -10,42 +10,72 @@
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <memory>
#include "../include/rabit/internal/engine.h"
#include "./allreduce_base.h"
#include "./allreduce_robust.h"
#include "./thread_local.h"
namespace rabit {
namespace engine {
// singleton sync manager
#ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK
AllreduceRobust manager;
typedef AllreduceRobust Manager;
#else
AllreduceMock manager;
typedef AllreduceMock Manager;
#endif
#else
AllreduceBase manager;
typedef AllreduceBase Manager;
#endif
/*! \brief entry to to easily hold returning information */
struct ThreadLocalEntry {
/*! \brief stores the current engine */
std::unique_ptr<Manager> engine;
/*! \brief whether init has been called */
bool initialized;
/*! \brief constructor */
ThreadLocalEntry() : initialized(false) {}
};
// define the threadlocal store.
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
/*! \brief initialize the synchronization module */
void Init(int argc, char *argv[]) {
for (int i = 1; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
manager.SetParam(name, val);
}
}
manager.Init();
ThreadLocalEntry* e = EngineThreadLocal::Get();
utils::Check(e->engine.get() == nullptr,
"rabit::Init is already called in this thread");
e->initialized = true;
e->engine.reset(new Manager());
e->engine->Init(argc, argv);
}
/*! \brief finalize synchronization module */
void Finalize(void) {
manager.Shutdown();
void Finalize() {
ThreadLocalEntry* e = EngineThreadLocal::Get();
utils::Check(e->engine.get() != nullptr,
"rabit::Finalize engine is not initialized or already been finalized.");
e->engine->Shutdown();
e->engine.reset(nullptr);
}
/*! \brief singleton method to get engine */
IEngine *GetEngine(void) {
return &manager;
IEngine *GetEngine() {
// un-initialized default manager.
static AllreduceBase default_manager;
ThreadLocalEntry* e = EngineThreadLocal::Get();
IEngine* ptr = e->engine.get();
if (ptr == nullptr) {
utils::Check(!e->initialized,
"Doing rabit call after Finalize");
return &default_manager;
} else {
return ptr;
}
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
@ -63,15 +93,18 @@ void Allreduce_(void *sendrecvbuf,
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size);
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,

87
src/thread_local.h Normal file
View File

@ -0,0 +1,87 @@
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Common utility for thread local storage.
*/
#ifndef RABIT_THREAD_LOCAL_H_
#define RABIT_THREAD_LOCAL_H_
#include "../include/dmlc/base.h"
#if DMLC_ENABLE_STD_THREAD
#include <mutex>
#endif
#include <memory>
#include <vector>
namespace rabit {
// Macro handling for thread-local variables.
// MX_TREAD_LOCAL (historical spelling, kept for compatibility) expands to the
// keyword used to declare a thread-local variable on the current compiler.
// Compiler-specific keywords are preferred; standard C++11 `thread_local` is
// the fallback. MSVC is tested before __cplusplus because MSVC may report a
// pre-C++11 __cplusplus value unless asked otherwise.
#if defined(__GNUC__)
#define MX_TREAD_LOCAL __thread
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#elif __cplusplus >= 201103L
#define MX_TREAD_LOCAL thread_local
#endif
#ifndef MX_TREAD_LOCAL
// Without thread-local storage the declaration in ThreadLocalStore::Get()
// cannot compile, so fail here with a readable diagnostic.
#error "Thread-local storage is not available: MX_TREAD_LOCAL could not be defined"
#endif
/*!
 * \brief A threadlocal store to store threadlocal variables.
 *  Will return a thread local singleton of type T.
 *
 *  Each thread gets its own lazily created T. The objects are owned by a
 *  process-wide store and are deleted when that store is destroyed, not
 *  when the creating thread exits.
 * \tparam T the type we like to store; must be default constructible
 */
template<typename T>
class ThreadLocalStore {
 public:
  /*!
   * \brief get the calling thread's instance of T, creating it on first use.
   * \return pointer to the thread local singleton; never null
   */
  static T* Get() {
    static MX_TREAD_LOCAL T* ptr = nullptr;
    if (ptr == nullptr) {
      // Hold the new object in a unique_ptr until it is registered, so a
      // failing registration neither leaks it nor leaves ptr pointing at an
      // object the store does not know about.
      std::unique_ptr<T> fresh(new T());
      Singleton()->RegisterDelete(fresh.get());
      ptr = fresh.release();
    }
    return ptr;
  }

 private:
  /*! \brief constructor */
  ThreadLocalStore() {}
  /*! \brief destructor, deletes every registered object in registration order */
  ~ThreadLocalStore() {
    for (size_t i = 0; i < data_.size(); ++i) {
      delete data_[i];
    }
  }
  // The store owns raw pointers; copying it would delete them twice.
  ThreadLocalStore(const ThreadLocalStore&) = delete;
  ThreadLocalStore& operator=(const ThreadLocalStore&) = delete;
  /*! \return singleton of the store */
  static ThreadLocalStore<T> *Singleton() {
    static ThreadLocalStore<T> inst;
    return &inst;
  }
  /*!
   * \brief register an object for deletion when the store is destroyed
   * \param obj the object to take ownership of
   */
  void RegisterDelete(T *obj) {
#if DMLC_ENABLE_STD_THREAD
    // several threads may register their first instance concurrently
    std::lock_guard<std::mutex> lock(mutex_);
#endif
    data_.push_back(obj);
  }
#if DMLC_ENABLE_STD_THREAD
  /*! \brief internal mutex guarding data_ */
  std::mutex mutex_;
#endif
  /*! \brief every object handed out by Get(), across all threads */
  std::vector<T*> data_;
};
} // namespace rabit
#endif // RABIT_THREAD_LOCAL_H_