Make rabit library thread local
This commit is contained in:
parent
aeb4008606
commit
be50e7b632
@ -2,19 +2,19 @@ export CC = gcc
|
|||||||
export CXX = g++
|
export CXX = g++
|
||||||
export MPICXX = mpicxx
|
export MPICXX = mpicxx
|
||||||
export LDFLAGS= -pthread -lm -L../lib
|
export LDFLAGS= -pthread -lm -L../lib
|
||||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include
|
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -fopenmp -I../include
|
||||||
|
|
||||||
.PHONY: clean all lib libmpi
|
.PHONY: clean all lib libmpi
|
||||||
BIN = basic.rabit broadcast.rabit
|
BIN = basic.rabit broadcast.rabit
|
||||||
MOCKBIN= lazy_allreduce.mock
|
MOCKBIN= lazy_allreduce.mock
|
||||||
|
|
||||||
all: $(BIN)
|
all: $(BIN)
|
||||||
basic.rabit: basic.cc lib
|
basic.rabit: basic.cc lib ../lib/librabit.a
|
||||||
broadcast.rabit: broadcast.cc lib
|
broadcast.rabit: broadcast.cc lib ../lib/librabit.a
|
||||||
lazy_allreduce.mock: lazy_allreduce.cc lib
|
lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a
|
||||||
|
|
||||||
$(BIN) :
|
$(BIN) :
|
||||||
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit
|
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
|
||||||
|
|
||||||
$(MOCKBIN) :
|
$(MOCKBIN) :
|
||||||
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
|
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <rabit.h>
|
#include <rabit/rabit.h>
|
||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
int N = 3;
|
int N = 3;
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
#include <rabit.h>
|
#include <rabit/rabit.h>
|
||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
const int N = 3;
|
const int N = 3;
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|||||||
@ -5,7 +5,8 @@
|
|||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#include <rabit.h>
|
#include <rabit/rabit.h>
|
||||||
|
|
||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
const int N = 3;
|
const int N = 3;
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|||||||
@ -51,7 +51,7 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
void AllreduceBase::Init(void) {
|
void AllreduceBase::Init(int argc, char* argv[]) {
|
||||||
// setup from enviroment variables
|
// setup from enviroment variables
|
||||||
// handler to get variables from env
|
// handler to get variables from env
|
||||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
for (size_t i = 0; i < env_vars.size(); ++i) {
|
||||||
@ -60,6 +60,14 @@ void AllreduceBase::Init(void) {
|
|||||||
this->SetParam(env_vars[i].c_str(), value);
|
this->SetParam(env_vars[i].c_str(), value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// pass in arguments override env variable.
|
||||||
|
for (int i = 0; i < argc; ++i) {
|
||||||
|
char name[256], val[256];
|
||||||
|
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||||
|
this->SetParam(name, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// handling for hadoop
|
// handling for hadoop
|
||||||
const char *task_id = getenv("mapred_tip_id");
|
const char *task_id = getenv("mapred_tip_id");
|
||||||
|
|||||||
@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
|
|||||||
AllreduceBase(void);
|
AllreduceBase(void);
|
||||||
virtual ~AllreduceBase(void) {}
|
virtual ~AllreduceBase(void) {}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
virtual void Init(void);
|
virtual void Init(int argc, char* argv[]);
|
||||||
// shutdown the engine
|
// shutdown the engine
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -31,8 +31,8 @@ AllreduceRobust::AllreduceRobust(void) {
|
|||||||
env_vars.push_back("rabit_global_replica");
|
env_vars.push_back("rabit_global_replica");
|
||||||
env_vars.push_back("rabit_local_replica");
|
env_vars.push_back("rabit_local_replica");
|
||||||
}
|
}
|
||||||
void AllreduceRobust::Init(void) {
|
void AllreduceRobust::Init(int argc, char* argv[]) {
|
||||||
AllreduceBase::Init();
|
AllreduceBase::Init(argc, argv);
|
||||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||||
}
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
|
|||||||
@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
AllreduceRobust(void);
|
AllreduceRobust(void);
|
||||||
virtual ~AllreduceRobust(void) {}
|
virtual ~AllreduceRobust(void) {}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
virtual void Init(void);
|
virtual void Init(int argc, char* argv[]);
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -10,42 +10,72 @@
|
|||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include "../include/rabit/internal/engine.h"
|
#include "../include/rabit/internal/engine.h"
|
||||||
#include "./allreduce_base.h"
|
#include "./allreduce_base.h"
|
||||||
#include "./allreduce_robust.h"
|
#include "./allreduce_robust.h"
|
||||||
|
#include "./thread_local.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
#ifndef RABIT_USE_BASE
|
#ifndef RABIT_USE_BASE
|
||||||
#ifndef RABIT_USE_MOCK
|
#ifndef RABIT_USE_MOCK
|
||||||
AllreduceRobust manager;
|
typedef AllreduceRobust Manager;
|
||||||
#else
|
#else
|
||||||
AllreduceMock manager;
|
typedef AllreduceMock Manager;
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
AllreduceBase manager;
|
typedef AllreduceBase Manager;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/*! \brief entry to to easily hold returning information */
|
||||||
|
struct ThreadLocalEntry {
|
||||||
|
/*! \brief stores the current engine */
|
||||||
|
std::unique_ptr<Manager> engine;
|
||||||
|
/*! \brief whether init has been called */
|
||||||
|
bool initialized;
|
||||||
|
/*! \brief constructor */
|
||||||
|
ThreadLocalEntry() : initialized(false) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// define the threadlocal store.
|
||||||
|
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
||||||
|
|
||||||
/*! \brief intiialize the synchronization module */
|
/*! \brief intiialize the synchronization module */
|
||||||
void Init(int argc, char *argv[]) {
|
void Init(int argc, char *argv[]) {
|
||||||
for (int i = 1; i < argc; ++i) {
|
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||||
char name[256], val[256];
|
utils::Check(e->engine.get() == nullptr,
|
||||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
"rabit::Init is already called in this thread");
|
||||||
manager.SetParam(name, val);
|
e->initialized = true;
|
||||||
}
|
e->engine.reset(new Manager());
|
||||||
}
|
e->engine->Init(argc, argv);
|
||||||
manager.Init();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief finalize syncrhonization module */
|
/*! \brief finalize syncrhonization module */
|
||||||
void Finalize(void) {
|
void Finalize() {
|
||||||
manager.Shutdown();
|
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||||
|
utils::Check(e->engine.get() != nullptr,
|
||||||
|
"rabit::Finalize engine is not initialized or already been finalized.");
|
||||||
|
e->engine->Shutdown();
|
||||||
|
e->engine.reset(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void) {
|
IEngine *GetEngine() {
|
||||||
return &manager;
|
// un-initialized default manager.
|
||||||
|
static AllreduceBase default_manager;
|
||||||
|
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||||
|
IEngine* ptr = e->engine.get();
|
||||||
|
if (ptr == nullptr) {
|
||||||
|
utils::Check(!e->initialized,
|
||||||
|
"Doing rabit call after Finalize");
|
||||||
|
return &default_manager;
|
||||||
|
} else {
|
||||||
|
return ptr;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// perform in-place allreduce, on sendrecvbuf
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
void Allreduce_(void *sendrecvbuf,
|
void Allreduce_(void *sendrecvbuf,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
@ -63,15 +93,18 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
ReduceHandle::ReduceHandle(void)
|
ReduceHandle::ReduceHandle(void)
|
||||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
ReduceHandle::~ReduceHandle(void) {}
|
||||||
|
|
||||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
return static_cast<int>(dtype.type_size);
|
return static_cast<int>(dtype.type_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||||
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
|
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
|
||||||
redfunc_ = redfunc;
|
redfunc_ = redfunc;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||||
size_t type_nbytes, size_t count,
|
size_t type_nbytes, size_t count,
|
||||||
IEngine::PreprocFunction prepare_fun,
|
IEngine::PreprocFunction prepare_fun,
|
||||||
|
|||||||
87
src/thread_local.h
Normal file
87
src/thread_local.h
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright (c) 2015 by Contributors
|
||||||
|
* \file thread_local.h
|
||||||
|
* \brief Common utility for thread local storage.
|
||||||
|
*/
|
||||||
|
#ifndef RABIT_THREAD_LOCAL_H_
|
||||||
|
#define RABIT_THREAD_LOCAL_H_
|
||||||
|
|
||||||
|
#include "../include/dmlc/base.h"
|
||||||
|
|
||||||
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
|
#include <mutex>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace rabit {
|
||||||
|
|
||||||
|
// macro hanlding for threadlocal variables
|
||||||
|
#ifdef __GNUC__
|
||||||
|
#define MX_TREAD_LOCAL __thread
|
||||||
|
#elif __STDC_VERSION__ >= 201112L
|
||||||
|
#define MX_TREAD_LOCAL _Thread_local
|
||||||
|
#elif defined(_MSC_VER)
|
||||||
|
#define MX_TREAD_LOCAL __declspec(thread)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef MX_TREAD_LOCAL
|
||||||
|
#message("Warning: Threadlocal is not enabled");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief A threadlocal store to store threadlocal variables.
|
||||||
|
* Will return a thread local singleton of type T
|
||||||
|
* \tparam T the type we like to store
|
||||||
|
*/
|
||||||
|
template<typename T>
|
||||||
|
class ThreadLocalStore {
|
||||||
|
public:
|
||||||
|
/*! \return get a thread local singleton */
|
||||||
|
static T* Get() {
|
||||||
|
static MX_TREAD_LOCAL T* ptr = nullptr;
|
||||||
|
if (ptr == nullptr) {
|
||||||
|
ptr = new T();
|
||||||
|
Singleton()->RegisterDelete(ptr);
|
||||||
|
}
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief constructor */
|
||||||
|
ThreadLocalStore() {}
|
||||||
|
/*! \brief destructor */
|
||||||
|
~ThreadLocalStore() {
|
||||||
|
for (size_t i = 0; i < data_.size(); ++i) {
|
||||||
|
delete data_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*! \return singleton of the store */
|
||||||
|
static ThreadLocalStore<T> *Singleton() {
|
||||||
|
static ThreadLocalStore<T> inst;
|
||||||
|
return &inst;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief register str for internal deletion
|
||||||
|
* \param str the string pointer
|
||||||
|
*/
|
||||||
|
void RegisterDelete(T *str) {
|
||||||
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
data_.push_back(str);
|
||||||
|
lock.unlock();
|
||||||
|
#else
|
||||||
|
data_.push_back(str);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#if DMLC_ENABLE_STD_THREAD
|
||||||
|
/*! \brief internal mutex */
|
||||||
|
std::mutex mutex_;
|
||||||
|
#endif
|
||||||
|
/*!\brief internal data */
|
||||||
|
std::vector<T*> data_;
|
||||||
|
};
|
||||||
|
} // namespace rabit
|
||||||
|
#endif // RABIT_THREAD_LOCAL_H_
|
||||||
Loading…
x
Reference in New Issue
Block a user