From be50e7b63224b9fb7ff94ce34df9f8752ef83043 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 1 Mar 2016 20:12:51 -0800 Subject: [PATCH] Make rabit library thread local --- guide/Makefile | 18 ++++----- guide/basic.cc | 6 +-- guide/broadcast.cc | 2 +- guide/lazy_allreduce.cc | 11 +++--- src/allreduce_base.cc | 10 ++++- src/allreduce_base.h | 2 +- src/allreduce_robust.cc | 4 +- src/allreduce_robust.h | 2 +- src/engine.cc | 61 ++++++++++++++++++++++------- src/thread_local.h | 87 +++++++++++++++++++++++++++++++++++++++++ 10 files changed, 166 insertions(+), 37 deletions(-) create mode 100644 src/thread_local.h diff --git a/guide/Makefile b/guide/Makefile index 7213e1bf7..7b04231b8 100644 --- a/guide/Makefile +++ b/guide/Makefile @@ -2,25 +2,25 @@ export CC = gcc export CXX = g++ export MPICXX = mpicxx export LDFLAGS= -pthread -lm -L../lib -export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include +export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -fopenmp -I../include .PHONY: clean all lib libmpi BIN = basic.rabit broadcast.rabit MOCKBIN= lazy_allreduce.mock all: $(BIN) -basic.rabit: basic.cc lib -broadcast.rabit: broadcast.cc lib -lazy_allreduce.mock: lazy_allreduce.cc lib +basic.rabit: basic.cc lib ../lib/librabit.a +broadcast.rabit: broadcast.cc lib ../lib/librabit.a +lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a -$(BIN) : - $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit +$(BIN) : + $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS) -$(MOCKBIN) : +$(MOCKBIN) : $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock -$(OBJ) : +$(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) clean: - $(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~ \ No newline at end of file + $(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~ diff --git a/guide/basic.cc b/guide/basic.cc index a9a729170..d08397b54 100644 --- a/guide/basic.cc +++ b/guide/basic.cc @@ -8,7 +8,7 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #include -#include +#include using namespace rabit; int main(int argc, char *argv[]) { int N = 3; @@ -19,7 +19,7 @@ int main(int argc, char *argv[]) { rabit::Init(argc, argv); for (int i = 0; i < N; ++i) { a[i] = rabit::GetRank() + i; - } + } printf("@node[%d] before-allreduce: a={%d, %d, %d}\n", rabit::GetRank(), a[0], a[1], a[2]); // allreduce take max of each elements in all processes @@ -29,7 +29,7 @@ int main(int argc, char *argv[]) { // second allreduce that sums everything up Allreduce(&a[0], N); printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); + rabit::GetRank(), a[0], a[1], a[2]); rabit::Finalize(); return 0; } diff --git a/guide/broadcast.cc b/guide/broadcast.cc index 83dbe67fe..9e360d8de 100644 --- a/guide/broadcast.cc +++ b/guide/broadcast.cc @@ -1,4 +1,4 @@ -#include +#include using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { diff --git a/guide/lazy_allreduce.cc b/guide/lazy_allreduce.cc index b54776ecc..b4b816fa0 100644 --- a/guide/lazy_allreduce.cc +++ b/guide/lazy_allreduce.cc @@ -5,7 +5,8 @@ * * \author Tianqi Chen */ -#include +#include + using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { @@ -16,18 +17,18 @@ int main(int argc, char *argv[]) { printf("@node[%d] run prepare function\n", rabit::GetRank()); for (int i = 0; i < N; ++i) { a[i] = rabit::GetRank() + i; - } + } }; printf("@node[%d] before-allreduce: a={%d, %d, %d}\n", rabit::GetRank(), a[0], a[1], a[2]); // allreduce take max of each elements in all processes - Allreduce(&a[0], N, prepare); + Allreduce(&a[0], N, prepare); printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); + rabit::GetRank(), a[0], a[1], a[2]); // rum second allreduce Allreduce(&a[0], N); printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n", - rabit::GetRank(), a[0], a[1], a[2]); + rabit::GetRank(), a[0], a[1], a[2]); rabit::Finalize(); return 0; } diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 2600fc83b..862187bc1 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -51,7 +51,7 @@ AllreduceBase::AllreduceBase(void) { } // initialization function -void AllreduceBase::Init(void) { +void AllreduceBase::Init(int argc, char* argv[]) { // setup from enviroment variables // handler to get variables from env for (size_t i = 0; i < env_vars.size(); ++i) { @@ -60,6 +60,14 @@ void AllreduceBase::Init(void) { this->SetParam(env_vars[i].c_str(), value); } } + // pass in arguments override env variable. + for (int i = 0; i < argc; ++i) { + char name[256], val[256]; + if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { + this->SetParam(name, val); + } + } + { // handling for hadoop const char *task_id = getenv("mapred_tip_id"); diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 9a2cb3fb9..4194beb13 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -38,7 +38,7 @@ class AllreduceBase : public IEngine { AllreduceBase(void); virtual ~AllreduceBase(void) {} // initialize the manager - virtual void Init(void); + virtual void Init(int argc, char* argv[]); // shutdown the engine virtual void Shutdown(void); /*! diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 3fd76782a..c89b69542 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -31,8 +31,8 @@ AllreduceRobust::AllreduceRobust(void) { env_vars.push_back("rabit_global_replica"); env_vars.push_back("rabit_local_replica"); } -void AllreduceRobust::Init(void) { - AllreduceBase::Init(); +void AllreduceRobust::Init(int argc, char* argv[]) { + AllreduceBase::Init(argc, argv); result_buffer_round = std::max(world_size / num_global_replica, 1); } /*! \brief shutdown the engine */ diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 46e9f69c4..c8860822d 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase { AllreduceRobust(void); virtual ~AllreduceRobust(void) {} // initialize the manager - virtual void Init(void); + virtual void Init(int argc, char* argv[]); /*! \brief shutdown the engine */ virtual void Shutdown(void); /*! diff --git a/src/engine.cc b/src/engine.cc index 296775d85..c958932bd 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -10,42 +10,72 @@ #define _CRT_SECURE_NO_DEPRECATE #define NOMINMAX +#include #include "../include/rabit/internal/engine.h" #include "./allreduce_base.h" #include "./allreduce_robust.h" +#include "./thread_local.h" namespace rabit { namespace engine { // singleton sync manager #ifndef RABIT_USE_BASE #ifndef RABIT_USE_MOCK -AllreduceRobust manager; +typedef AllreduceRobust Manager; #else -AllreduceMock manager; +typedef AllreduceMock Manager; #endif #else -AllreduceBase manager; +typedef AllreduceBase Manager; #endif +/*! \brief entry to to easily hold returning information */ +struct ThreadLocalEntry { + /*! \brief stores the current engine */ + std::unique_ptr engine; + /*! \brief whether init has been called */ + bool initialized; + /*! \brief constructor */ + ThreadLocalEntry() : initialized(false) {} +}; + +// define the threadlocal store. +typedef ThreadLocalStore EngineThreadLocal; + /*! \brief intiialize the synchronization module */ void Init(int argc, char *argv[]) { - for (int i = 1; i < argc; ++i) { - char name[256], val[256]; - if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { - manager.SetParam(name, val); - } - } - manager.Init(); + ThreadLocalEntry* e = EngineThreadLocal::Get(); + utils::Check(e->engine.get() == nullptr, + "rabit::Init is already called in this thread"); + e->initialized = true; + e->engine.reset(new Manager()); + e->engine->Init(argc, argv); } /*! \brief finalize syncrhonization module */ -void Finalize(void) { - manager.Shutdown(); +void Finalize() { + ThreadLocalEntry* e = EngineThreadLocal::Get(); + utils::Check(e->engine.get() != nullptr, + "rabit::Finalize engine is not initialized or already been finalized."); + e->engine->Shutdown(); + e->engine.reset(nullptr); } + /*! \brief singleton method to get engine */ -IEngine *GetEngine(void) { - return &manager; +IEngine *GetEngine() { + // un-initialized default manager. + static AllreduceBase default_manager; + ThreadLocalEntry* e = EngineThreadLocal::Get(); + IEngine* ptr = e->engine.get(); + if (ptr == nullptr) { + utils::Check(!e->initialized, + "Doing rabit call after Finalize"); + return &default_manager; + } else { + return ptr; + } } + // perform in-place allreduce, on sendrecvbuf void Allreduce_(void *sendrecvbuf, size_t type_nbytes, @@ -63,15 +93,18 @@ void Allreduce_(void *sendrecvbuf, ReduceHandle::ReduceHandle(void) : handle_(NULL), redfunc_(NULL), htype_(NULL) { } + ReduceHandle::~ReduceHandle(void) {} int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast(dtype.type_size); } + void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice"); redfunc_ = redfunc; } + void ReduceHandle::Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::PreprocFunction prepare_fun, diff --git a/src/thread_local.h b/src/thread_local.h new file mode 100644 index 000000000..bd504b0e8 --- /dev/null +++ b/src/thread_local.h @@ -0,0 +1,87 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file thread_local.h + * \brief Common utility for thread local storage. + */ +#ifndef RABIT_THREAD_LOCAL_H_ +#define RABIT_THREAD_LOCAL_H_ + +#include "../include/dmlc/base.h" + +#if DMLC_ENABLE_STD_THREAD +#include +#endif + +#include +#include + +namespace rabit { + +// macro hanlding for threadlocal variables +#ifdef __GNUC__ + #define MX_TREAD_LOCAL __thread +#elif __STDC_VERSION__ >= 201112L + #define MX_TREAD_LOCAL _Thread_local +#elif defined(_MSC_VER) + #define MX_TREAD_LOCAL __declspec(thread) +#endif + +#ifndef MX_TREAD_LOCAL +#message("Warning: Threadlocal is not enabled"); +#endif + +/*! + * \brief A threadlocal store to store threadlocal variables. + * Will return a thread local singleton of type T + * \tparam T the type we like to store + */ +template +class ThreadLocalStore { + public: + /*! \return get a thread local singleton */ + static T* Get() { + static MX_TREAD_LOCAL T* ptr = nullptr; + if (ptr == nullptr) { + ptr = new T(); + Singleton()->RegisterDelete(ptr); + } + return ptr; + } + + private: + /*! \brief constructor */ + ThreadLocalStore() {} + /*! \brief destructor */ + ~ThreadLocalStore() { + for (size_t i = 0; i < data_.size(); ++i) { + delete data_[i]; + } + } + /*! \return singleton of the store */ + static ThreadLocalStore *Singleton() { + static ThreadLocalStore inst; + return &inst; + } + /*! + * \brief register str for internal deletion + * \param str the string pointer + */ + void RegisterDelete(T *str) { +#if DMLC_ENABLE_STD_THREAD + std::unique_lock lock(mutex_); + data_.push_back(str); + lock.unlock(); +#else + data_.push_back(str); +#endif + } + +#if DMLC_ENABLE_STD_THREAD + /*! \brief internal mutex */ + std::mutex mutex_; +#endif + /*!\brief internal data */ + std::vector data_; +}; +} // namespace rabit +#endif // RABIT_THREAD_LOCAL_H_