Make rabit library thread local

This commit is contained in:
tqchen 2016-03-01 20:12:51 -08:00
parent aeb4008606
commit be50e7b632
10 changed files with 166 additions and 37 deletions

View File

@ -2,19 +2,19 @@ export CC = gcc
export CXX = g++ export CXX = g++
export MPICXX = mpicxx export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../lib export LDFLAGS= -pthread -lm -L../lib
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -fopenmp -I../include
.PHONY: clean all lib libmpi .PHONY: clean all lib libmpi
BIN = basic.rabit broadcast.rabit BIN = basic.rabit broadcast.rabit
MOCKBIN= lazy_allreduce.mock MOCKBIN= lazy_allreduce.mock
all: $(BIN) all: $(BIN)
basic.rabit: basic.cc lib basic.rabit: basic.cc lib ../lib/librabit.a
broadcast.rabit: broadcast.cc lib broadcast.rabit: broadcast.cc lib ../lib/librabit.a
lazy_allreduce.mock: lazy_allreduce.cc lib lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a
$(BIN) : $(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
$(MOCKBIN) : $(MOCKBIN) :
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock

View File

@ -8,7 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#include <vector> #include <vector>
#include <rabit.h> #include <rabit/rabit.h>
using namespace rabit; using namespace rabit;
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
int N = 3; int N = 3;

View File

@ -1,4 +1,4 @@
#include <rabit.h> #include <rabit/rabit.h>
using namespace rabit; using namespace rabit;
const int N = 3; const int N = 3;
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {

View File

@ -5,7 +5,8 @@
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#include <rabit.h> #include <rabit/rabit.h>
using namespace rabit; using namespace rabit;
const int N = 3; const int N = 3;
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {

View File

@ -51,7 +51,7 @@ AllreduceBase::AllreduceBase(void) {
} }
// initialization function // initialization function
void AllreduceBase::Init(void) { void AllreduceBase::Init(int argc, char* argv[]) {
// setup from enviroment variables // setup from enviroment variables
// handler to get variables from env // handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) { for (size_t i = 0; i < env_vars.size(); ++i) {
@ -60,6 +60,14 @@ void AllreduceBase::Init(void) {
this->SetParam(env_vars[i].c_str(), value); this->SetParam(env_vars[i].c_str(), value);
} }
} }
// Command-line arguments of the form name=value are applied after the
// environment variables, so they override them.
for (int i = 0; i < argc; ++i) {
  char name[256], val[256];
  // Both conversions are capped at 255 characters (plus the terminating NUL)
  // so an over-long argument cannot write past the stack buffers. A name
  // longer than 255 characters fails to match the '=' and is skipped; a value
  // longer than 255 characters is truncated.
  if (sscanf(argv[i], "%255[^=]=%255s", name, val) == 2) {
    this->SetParam(name, val);
  }
}
{ {
// handling for hadoop // handling for hadoop
const char *task_id = getenv("mapred_tip_id"); const char *task_id = getenv("mapred_tip_id");

View File

@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
AllreduceBase(void); AllreduceBase(void);
virtual ~AllreduceBase(void) {} virtual ~AllreduceBase(void) {}
// initialize the manager // initialize the manager
virtual void Init(void); virtual void Init(int argc, char* argv[]);
// shutdown the engine // shutdown the engine
virtual void Shutdown(void); virtual void Shutdown(void);
/*! /*!

View File

@ -31,8 +31,8 @@ AllreduceRobust::AllreduceRobust(void) {
env_vars.push_back("rabit_global_replica"); env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica"); env_vars.push_back("rabit_local_replica");
} }
void AllreduceRobust::Init(void) { void AllreduceRobust::Init(int argc, char* argv[]) {
AllreduceBase::Init(); AllreduceBase::Init(argc, argv);
result_buffer_round = std::max(world_size / num_global_replica, 1); result_buffer_round = std::max(world_size / num_global_replica, 1);
} }
/*! \brief shutdown the engine */ /*! \brief shutdown the engine */

View File

@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase {
AllreduceRobust(void); AllreduceRobust(void);
virtual ~AllreduceRobust(void) {} virtual ~AllreduceRobust(void) {}
// initialize the manager // initialize the manager
virtual void Init(void); virtual void Init(int argc, char* argv[]);
/*! \brief shutdown the engine */ /*! \brief shutdown the engine */
virtual void Shutdown(void); virtual void Shutdown(void);
/*! /*!

View File

@ -10,42 +10,72 @@
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX #define NOMINMAX
#include <memory>
#include "../include/rabit/internal/engine.h" #include "../include/rabit/internal/engine.h"
#include "./allreduce_base.h" #include "./allreduce_base.h"
#include "./allreduce_robust.h" #include "./allreduce_robust.h"
#include "./thread_local.h"
namespace rabit { namespace rabit {
namespace engine { namespace engine {
// singleton sync manager // singleton sync manager
#ifndef RABIT_USE_BASE #ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK #ifndef RABIT_USE_MOCK
AllreduceRobust manager; typedef AllreduceRobust Manager;
#else #else
AllreduceMock manager; typedef AllreduceMock Manager;
#endif #endif
#else #else
AllreduceBase manager; typedef AllreduceBase Manager;
#endif #endif
/*!
 * \brief per-thread engine state kept in the thread local store
 */
struct ThreadLocalEntry {
  /*! \brief start with no engine and with Init not yet called */
  ThreadLocalEntry() : engine(), initialized(false) {}
  /*! \brief engine owned by this entry; null while no engine is set */
  std::unique_ptr<Manager> engine;
  /*! \brief becomes true once Init has been called */
  bool initialized;
};
/*! \brief thread local store that hands out the ThreadLocalEntry singleton */
using EngineThreadLocal = ThreadLocalStore<ThreadLocalEntry>;
/*! \brief intiialize the synchronization module */ /*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) { void Init(int argc, char *argv[]) {
for (int i = 1; i < argc; ++i) { ThreadLocalEntry* e = EngineThreadLocal::Get();
char name[256], val[256]; utils::Check(e->engine.get() == nullptr,
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { "rabit::Init is already called in this thread");
manager.SetParam(name, val); e->initialized = true;
} e->engine.reset(new Manager());
} e->engine->Init(argc, argv);
manager.Init();
} }
/*! \brief finalize syncrhonization module */ /*! \brief finalize syncrhonization module */
void Finalize(void) { void Finalize() {
manager.Shutdown(); ThreadLocalEntry* e = EngineThreadLocal::Get();
utils::Check(e->engine.get() != nullptr,
"rabit::Finalize engine is not initialized or already been finalized.");
e->engine->Shutdown();
e->engine.reset(nullptr);
} }
/*! \brief singleton method to get engine */ /*! \brief singleton method to get engine */
IEngine *GetEngine(void) { IEngine *GetEngine() {
return &manager; // un-initialized default manager.
static AllreduceBase default_manager;
ThreadLocalEntry* e = EngineThreadLocal::Get();
IEngine* ptr = e->engine.get();
if (ptr == nullptr) {
utils::Check(!e->initialized,
"Doing rabit call after Finalize");
return &default_manager;
} else {
return ptr;
}
} }
// perform in-place allreduce, on sendrecvbuf // perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, void Allreduce_(void *sendrecvbuf,
size_t type_nbytes, size_t type_nbytes,
@ -63,15 +93,18 @@ void Allreduce_(void *sendrecvbuf,
ReduceHandle::ReduceHandle(void) ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) { : handle_(NULL), redfunc_(NULL), htype_(NULL) {
} }
ReduceHandle::~ReduceHandle(void) {} ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size); return static_cast<int>(dtype.type_size);
} }
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice"); utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
redfunc_ = redfunc; redfunc_ = redfunc;
} }
void ReduceHandle::Allreduce(void *sendrecvbuf, void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count, size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun, IEngine::PreprocFunction prepare_fun,

87
src/thread_local.h Normal file
View File

@ -0,0 +1,87 @@
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Common utility for thread local storage.
*/
#ifndef RABIT_THREAD_LOCAL_H_
#define RABIT_THREAD_LOCAL_H_
#include "../include/dmlc/base.h"
#if DMLC_ENABLE_STD_THREAD
#include <mutex>
#endif
#include <memory>
#include <vector>
namespace rabit {
// Macro handling for thread-local variables: pick the storage specifier
// understood by the current compiler.
#ifdef __GNUC__
#define MX_TREAD_LOCAL __thread
#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#endif
#ifndef MX_TREAD_LOCAL
// No known specifier: emit a build-time note and define the macro as empty so
// declarations using it still compile. Variables declared with it are then
// shared by all threads instead of being per-thread.
// NOTE(review): degrading to shared storage follows the wording of the
// warning below; confirm this is preferred over a hard #error.
#pragma message("Warning: Threadlocal is not enabled")
#define MX_TREAD_LOCAL
#endif
/*!
 * \brief A threadlocal store to store threadlocal variables.
 *  Will return a thread local singleton of type T.
 *
 *  Every instance handed out is recorded in a store singleton and deleted
 *  when that singleton is destroyed; nothing in this class deletes an
 *  instance earlier, so callers must not delete the returned pointer.
 * \tparam T the type we like to store; must be default constructible
 */
template<typename T>
class ThreadLocalStore {
 public:
  /*!
   * \brief get the thread local singleton, creating it on first use
   * \return pointer to the instance of T; owned by the store
   */
  static T* Get() {
    static MX_TREAD_LOCAL T* ptr = nullptr;
    if (ptr == nullptr) {
      // Keep the new object owned until registration succeeds, so that a
      // throwing RegisterDelete neither leaks it nor leaves ptr pointing
      // at an object the store does not know about.
      std::unique_ptr<T> holder(new T());
      Singleton()->RegisterDelete(holder.get());
      ptr = holder.release();
    }
    return ptr;
  }

 private:
  /*! \brief constructor, private: only Singleton() creates the store */
  ThreadLocalStore() {}
  /*! \brief destructor, deletes every registered instance */
  ~ThreadLocalStore() {
    for (size_t i = 0; i < data_.size(); ++i) {
      delete data_[i];
    }
  }
  // The store owns raw pointers, so copying it would lead to double deletes.
  ThreadLocalStore(const ThreadLocalStore&) = delete;
  ThreadLocalStore& operator=(const ThreadLocalStore&) = delete;
  /*! \return singleton of the store */
  static ThreadLocalStore<T> *Singleton() {
    static ThreadLocalStore<T> inst;
    return &inst;
  }
  /*!
   * \brief register an instance for deletion when the store is destroyed
   * \param str the instance pointer to take ownership of
   */
  void RegisterDelete(T *str) {
#if DMLC_ENABLE_STD_THREAD
    // Scoped guard: released on every exit path, including a throwing
    // push_back. Without DMLC_ENABLE_STD_THREAD the list is unguarded.
    std::lock_guard<std::mutex> lock(mutex_);
#endif
    data_.push_back(str);
  }
#if DMLC_ENABLE_STD_THREAD
  /*! \brief internal mutex guarding data_ */
  std::mutex mutex_;
#endif
  /*!\brief instances owned by the store */
  std::vector<T*> data_;
};
} // namespace rabit
#endif // RABIT_THREAD_LOCAL_H_