support bootstrap allreduce/broadcast (#98)
* support run rabit tests as xgboost subproject using xgboost/dmlc-core * support tracker config set/get * remove redudant printf * remove redudant printf * add c++0x declaration * log allreduce/broadcast caller, engine should track caller stack for investigation * tracker support binary config format * Revert "tracker support binary config format" This reverts commit 2a28e5e2b55c200cb621af8d19f17ab1bc62503b. * remove caller, prototype fetch allreduce/broadcast results from resbuf * store cached allreduce/broadcast seq_no to tracker * allow restore all caches from other nodes * try new rabit collective cache, todo: recv_link seems down * link up cache restore with main recovery * cleanup load cache state * update cache api * pass test.mk * have a working tests * try to unify check into actionsummary * more logging to debug distributed hist three method issue * update rabit interface to support caller signature matching * splite seq_counter from cur_cache_seq to different variables * still see issue with inf loop * support debug print caller as well as allreduce op * cleanup * remove get/set cache from model_recover, adding recover in loadcheckpoint * clarify rabit cache strategy, cache is set only by successful collective call involving all nodes with unique cache key. if all nodes call getcache at same time, we keep rabit run collective call. If some nodes call getcache while others not, we backfill cache from those nodes with most entries * revert caller logs * fix lint error * fix engine mpi signature * support getcache by ref * allow result buffer presiet to filestream * add loging * try fix checkpoint failure recovery case * use int64_t to avoid overflow caused seq fault * try avoid int overflow * try fix checkpoint failure recovery case * try avoid seqno overflow to negative by offseting specifial flag value adding cache seq no to checkpoint/load checkpoint/check point ack to avoid confusion from cache recovery * fix cache seq assert error * remove loging, handle edge case * add extensive log to checkpoint state with different seq no * fix lint errors * clean up comments before merge back to master * add logs to allreduce/broadcast/checkpoint * use unsinged int 32 and give seq no larger range * address remove allreduce dropseq code segment * using caller signature to filter bootstrapallreduces * remove get/set cache from empty * apply signature to reducer * apply signature to broadcast * add key to broadcat log * fix broadcast signature * fix default _line value for non linux system * adding comments, remove sleep(1) * fix osx build issue * try fix mpi * fix doc * fix engine_empty api * logging, adding more logs, restore immutable assertion * print unsinged int with ud * fix lint * rename seqtype to kSeq and KCache indicating it's usage apply kDiffSeq check to load_cache routine * comment allreduce/broadcast log * allow tests run on arm * enable flag to turn on / off cache * add log info alert if user choose to enable rabit bootstrap cache * add rabit_debug setting so user can use config to turn on * log flags when user turn on rabit_debug * force rabit restart if tracker assign -1 rank * use OPENMP to vecotrize reducer * address comment * Revert "address comment" This reverts commit 1dc61f33e7357dad8fa65528abeb81db92c5f9ed. * fix checkpoint size print 0 * per feedback, remove DISABLEOPEMP, address race condition * - remove openmp from this pr - update name from cache to boostrapcache * add default value of signature macros * remove openmp from cmake file * Update src/allreduce_robust.cc Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * Update src/allreduce_robust.cc Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * run test with cmake * remove openmp * fix cmake based tests * use cmake test fix darwin .dylib issue * move around rabit_signature definition due to windows build * misc, add c++ check in CMakeFile * per feedback * resolve CMake file * update rabit version
This commit is contained in:
parent
dba32d54d1
commit
5797dcb64e
@ -16,8 +16,7 @@ env:
|
||||
- TASK=doc
|
||||
- TASK=build
|
||||
- TASK=mpi-build
|
||||
- TASK=cmake-build
|
||||
- TASK=test CXX=g++
|
||||
- TASK=cmake-test
|
||||
|
||||
# dependent apt packages
|
||||
addons:
|
||||
|
||||
@ -1,6 +1,14 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
cmake_minimum_required(VERSION 3.3)
|
||||
|
||||
project(rabit VERSION 0.2.0)
|
||||
project(rabit VERSION 0.3.0)
|
||||
|
||||
include(CheckCXXCompilerFlag)
|
||||
CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11)
|
||||
if(COMPILER_SUPPORTS_CXX11)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
else()
|
||||
message(STATUS "The compiler ${CMAKE_CXX_COMPILER} has no C++11 support. Please use a different C++ compiler.")
|
||||
endif()
|
||||
|
||||
option(RABIT_BUILD_TESTS "Build rabit tests" OFF)
|
||||
option(RABIT_BUILD_MPI "Build MPI" OFF)
|
||||
@ -9,10 +17,11 @@ option(RABIT_BUILD_DMLC "Include DMLC_CORE in build" ON)
|
||||
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
|
||||
add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc)
|
||||
add_library(rabit_empty src/engine_empty.cc src/c_api.cc)
|
||||
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
|
||||
set(rabit_libs rabit rabit_base rabit_empty)
|
||||
|
||||
set_target_properties(rabit rabit_base rabit_empty PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON)
|
||||
set(rabit_libs rabit rabit_base rabit_empty rabit_mock rabit_mock_static)
|
||||
set_target_properties(rabit rabit_base rabit_empty rabit_mock rabit_mock_static PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
if(RABIT_BUILD_MPI)
|
||||
find_package(MPI REQUIRED)
|
||||
@ -21,11 +30,6 @@ if(RABIT_BUILD_MPI)
|
||||
list(APPEND rabit_libs rabit_mpi)
|
||||
endif()
|
||||
|
||||
if(RABIT_BUILD_TESTS)
|
||||
# Use static so that rabit_mock won't be installed when building shared libraries.
|
||||
add_library(rabit_mock STATIC src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
list(APPEND rabit_libs rabit_mock) # add to list to apply build settings, then remove
|
||||
endif()
|
||||
|
||||
if(RABIT_BUILD_DMLC)
|
||||
foreach(lib ${rabit_libs})
|
||||
@ -46,17 +50,19 @@ else()
|
||||
endif()
|
||||
|
||||
if(RABIT_BUILD_TESTS)
|
||||
list(REMOVE_ITEM rabit_libs "rabit_mock") # remove here to avoid installing it
|
||||
set(tests speed_test lazy_recover local_recover model_recover)
|
||||
list(REMOVE_ITEM rabit_libs "rabit_mock_static") # remove here to avoid installing it
|
||||
set(tests lazy_recover local_recover model_recover)
|
||||
|
||||
foreach(test ${tests})
|
||||
add_executable(${test} test/${test}.cc)
|
||||
target_link_libraries(${test} rabit_mock)
|
||||
install(TARGETS ${test} DESTINATION bin)
|
||||
target_link_libraries(${test} rabit_mock_static)
|
||||
set_target_properties(${test} PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON)
|
||||
install(TARGETS ${test} DESTINATION test)
|
||||
endforeach()
|
||||
if(RABIT_BUILD_MPI)
|
||||
add_executable(speed_test_mpi test/speed_test.cc)
|
||||
target_link_libraries(speed_test_mpi rabit_mpi)
|
||||
install(TARGETS speed_test_mpi DESTINATION bin)
|
||||
install(TARGETS speed_test_mpi DESTINATION test)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -66,6 +72,7 @@ endif()
|
||||
# * <prefix>/lib/cmake/<PROJECT-NAME>
|
||||
# * <prefix>/lib/
|
||||
# * <prefix>/include/
|
||||
set(CMAKE_INSTALL_PREFIX "../")
|
||||
set(config_install_dir "lib/cmake/${PROJECT_NAME}")
|
||||
set(include_install_dir "include")
|
||||
|
||||
|
||||
54
Makefile
54
Makefile
@ -2,14 +2,8 @@ OS := $(shell uname)
|
||||
|
||||
RABIT_BUILD_DMLC = 0
|
||||
|
||||
ifeq ($(RABIT_BUILD_DMLC),1)
|
||||
DMLC=dmlc-core
|
||||
else
|
||||
DMLC=../dmlc-core
|
||||
endif
|
||||
|
||||
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -std=c++11
|
||||
export CFLAGS = -O3 $(WARNFLAGS) -I $(DMLC)/include -I include/
|
||||
export CFLAGS = -O3 $(WARNFLAGS)
|
||||
export LDFLAGS =-Llib
|
||||
|
||||
#download mpi
|
||||
@ -17,46 +11,16 @@ export LDFLAGS =-Llib
|
||||
|
||||
MPICXX=./mpich/bin/mpicxx
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
ifndef CC
|
||||
export CC = gcc-4.9
|
||||
endif
|
||||
ifndef CXX
|
||||
export CXX = g++-4.9
|
||||
endif
|
||||
else
|
||||
ifeq ($(OS), FreeBSD)
|
||||
ifndef CXX
|
||||
export CXX = g++6
|
||||
endif
|
||||
export LDFLAGS= -Llib -Wl,-rpath=/usr/local/lib/gcc6
|
||||
else
|
||||
# linux defaults
|
||||
ifndef CC
|
||||
export CC = gcc
|
||||
endif
|
||||
ifndef CXX
|
||||
export CXX = g++
|
||||
endif
|
||||
LDFLAGS +=-lrt
|
||||
endif
|
||||
endif
|
||||
export CXX = g++
|
||||
|
||||
|
||||
#----------------------------
|
||||
# Settings for power and arm arch
|
||||
#----------------------------
|
||||
ARCH := $(shell uname -a)
|
||||
ifneq (,$(filter $(ARCH), powerpc64le ppc64le ))
|
||||
USE_SSE=0
|
||||
ifneq (,$(filter $(ARCH), armv6l armv7l powerpc64le ppc64le aarch64))
|
||||
CFLAGS += -march=native
|
||||
else
|
||||
USE_SSE=1
|
||||
endif
|
||||
|
||||
ifndef USE_SSE
|
||||
USE_SSE = 1
|
||||
endif
|
||||
|
||||
ifeq ($(USE_SSE), 1)
|
||||
CFLAGS += -msse2
|
||||
endif
|
||||
|
||||
@ -71,6 +35,14 @@ ifndef LINT_LANG
|
||||
LINT_LANG="all"
|
||||
endif
|
||||
|
||||
ifeq ($(RABIT_BUILD_DMLC),1)
|
||||
DMLC=dmlc-core
|
||||
else
|
||||
DMLC=../dmlc-core
|
||||
endif
|
||||
|
||||
CFLAGS += -I $(DMLC)/include -I include/
|
||||
|
||||
# build path
|
||||
BPATH=.
|
||||
# objectives that makes up rabit library
|
||||
|
||||
@ -9,6 +9,40 @@
|
||||
#include <string>
|
||||
#include "../serializable.h"
|
||||
|
||||
// keeps rabit api caller signature
|
||||
#ifndef RABIT_API_CALLER_SIGNATURE
|
||||
#define RABIT_API_CALLER_SIGNATURE
|
||||
|
||||
#ifdef __has_builtin
|
||||
|
||||
#if __has_builtin(__builtin_FILE)
|
||||
#define _FILE __builtin_FILE()
|
||||
#else
|
||||
#define _FILE "N/A"
|
||||
#endif // __has_builtin(__builtin_FILE)
|
||||
|
||||
#if __has_builtin(__builtin_LINE)
|
||||
#define _LINE __builtin_LINE()
|
||||
#else
|
||||
#define _LINE -1
|
||||
#endif // __has_builtin(__builtin_LINE)
|
||||
|
||||
#if __has_builtin(__builtin_FUNCTION)
|
||||
#define _CALLER __builtin_FUNCTION()
|
||||
#else
|
||||
#define _CALLER "N/A"
|
||||
#endif // __has_builtin(__builtin_FUNCTION)
|
||||
|
||||
#else
|
||||
|
||||
#define _FILE "N/A"
|
||||
#define _LINE -1
|
||||
#define _CALLER "N/A"
|
||||
|
||||
#endif // __has_builtin
|
||||
|
||||
#endif // RABIT_API_CALLER_SIGNATURE
|
||||
|
||||
namespace MPI {
|
||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||
class Datatype;
|
||||
@ -54,20 +88,36 @@ class IEngine {
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL) = 0;
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) = 0;
|
||||
/*!
|
||||
* \brief broadcasts data from root to every other node
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param is_bootstrap if this broadcast is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) = 0;
|
||||
/*!
|
||||
* \brief explicitly re-initialize everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throws an exception,
|
||||
@ -204,6 +254,10 @@ enum DataType {
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function.
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
@ -212,8 +266,11 @@ void Allreduce_(void *sendrecvbuf,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief handle for customized reducer, used to handle customized reduce
|
||||
* this class is mainly created for compatiblity issues with MPI's customized reduce
|
||||
@ -239,12 +296,20 @@ class ReduceHandle {
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*! \return the number of bytes occupied by the type */
|
||||
static int TypeSize(const MPI::Datatype &dtype);
|
||||
|
||||
|
||||
@ -96,6 +96,7 @@ template<typename OP, typename DType>
|
||||
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||
const DType *src = (const DType*)src_;
|
||||
DType *dst = (DType*)dst_; // NOLINT(*)
|
||||
|
||||
for (int i = 0; i < len; ++i) {
|
||||
OP::Reduce(dst[i], src[i]);
|
||||
}
|
||||
@ -127,28 +128,43 @@ inline std::string GetProcessorName(void) {
|
||||
return engine::GetEngine()->GetHost();
|
||||
}
|
||||
// broadcast data to all other nodes from root
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
|
||||
engine::GetEngine()->Broadcast(sendrecv_data, size, root);
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::GetEngine()->Broadcast(sendrecv_data, size, root,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
template<typename DType>
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
size_t size = sendrecv_data->size();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
Broadcast(&size, sizeof(size), root, is_bootstrap, _file, _line, _caller);
|
||||
if (sendrecv_data->size() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root);
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
}
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
inline void Broadcast(std::string *sendrecv_data, int root,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
Broadcast(&size, sizeof(size), root, is_bootstrap, _file, _line, _caller);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,9 +172,14 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
|
||||
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
@ -167,9 +188,15 @@ inline void InvokeLambda_(void *fun) {
|
||||
(*static_cast<std::function<void()>*>(fun))();
|
||||
}
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
|
||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
@ -188,6 +215,7 @@ inline void TrackerPrintf(const char *fmt, ...) {
|
||||
msg.resize(strlen(msg.c_str()));
|
||||
TrackerPrint(msg);
|
||||
}
|
||||
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(Serializable *global_model,
|
||||
@ -216,8 +244,8 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
|
||||
const size_t kUnit = sizeof(DType);
|
||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||
char *pdst = reinterpret_cast<char*>(dst_);
|
||||
DType tdst, tsrc;
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
DType tdst, tsrc;
|
||||
// use memcpy to avoid alignment issue
|
||||
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
||||
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
||||
@ -247,8 +275,13 @@ inline Reducer<DType, freduce>::Reducer(void) {
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun,
|
||||
prepare_arg, is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
// function to perform reduction for SerializeReducer
|
||||
template<typename DType>
|
||||
@ -256,8 +289,8 @@ inline void SerializeReducerFunc_(const void *src_, void *dst_,
|
||||
int len_, const MPI::Datatype &dtype) {
|
||||
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||
// temp space
|
||||
DType tsrc, tdst;
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
DType tsrc, tdst;
|
||||
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); // NOLINT(*)
|
||||
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); // NOLINT(*)
|
||||
tsrc.Load(fsrc);
|
||||
@ -296,7 +329,11 @@ template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
buffer_.resize(max_nbyte * count);
|
||||
// setup closure
|
||||
SerializeReduceClosure<DType> c;
|
||||
@ -304,7 +341,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
||||
// invoke here
|
||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
||||
SerializeReduceClosure<DType>::Invoke, &c);
|
||||
SerializeReduceClosure<DType>::Invoke, &c,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Load(fs);
|
||||
@ -314,14 +352,24 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
#if DMLC_USE_CXX11
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
|
||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun) {
|
||||
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
||||
std::function<void()> prepare_fun,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbytes, size_t count,
|
||||
std::function<void()> prepare_fun) {
|
||||
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun);
|
||||
std::function<void()> prepare_fun,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
#endif // DMLC_USE_CXX11
|
||||
} // namespace rabit
|
||||
|
||||
@ -96,9 +96,15 @@ inline void HandleCheckError(const char *msg) {
|
||||
inline void HandlePrint(const char *msg) {
|
||||
printf("%s", msg);
|
||||
}
|
||||
inline void HandleLogPrint(const char *msg) {
|
||||
fprintf(stderr, "%s", msg);
|
||||
fflush(stderr);
|
||||
|
||||
inline void HandleLogInfo(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
fprintf(stdout, "%s", msg.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
#else
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
|
||||
@ -22,6 +22,40 @@
|
||||
#endif // defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
|
||||
#endif // DMLC_USE_CXX11
|
||||
|
||||
// keeps rabit api caller signature
|
||||
#ifndef RABIT_API_CALLER_SIGNATURE
|
||||
#define RABIT_API_CALLER_SIGNATURE
|
||||
|
||||
#ifdef __has_builtin
|
||||
|
||||
#if __has_builtin(__builtin_FILE)
|
||||
#define _FILE __builtin_FILE()
|
||||
#else
|
||||
#define _FILE "N/A"
|
||||
#endif // __has_builtin(__builtin_FILE)
|
||||
|
||||
#if __has_builtin(__builtin_LINE)
|
||||
#define _LINE __builtin_LINE()
|
||||
#else
|
||||
#define _LINE -1
|
||||
#endif // __has_builtin(__builtin_LINE)
|
||||
|
||||
#if __has_builtin(__builtin_FUNCTION)
|
||||
#define _CALLER __builtin_FUNCTION()
|
||||
#else
|
||||
#define _CALLER "N/A"
|
||||
#endif // __has_builtin(__builtin_FUNCTION)
|
||||
|
||||
#else
|
||||
|
||||
#define _FILE "N/A"
|
||||
#define _LINE -1
|
||||
#define _CALLER "N/A"
|
||||
|
||||
#endif // __has_builtin
|
||||
|
||||
#endif // RABIT_API_CALLER_SIGNATURE
|
||||
|
||||
// optionally support of lambda functions in C++11, if available
|
||||
#if DMLC_USE_CXX11
|
||||
#include <functional>
|
||||
@ -101,6 +135,7 @@ inline std::string GetProcessorName();
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
inline void TrackerPrint(const std::string &msg);
|
||||
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief prints the msg to the tracker, this function may not be available
|
||||
@ -118,25 +153,50 @@ inline void TrackerPrintf(const char *fmt, ...);
|
||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||
* \param size the data size
|
||||
* \param root the process root
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root);
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
/*!
|
||||
* \brief broadcasts an std::vector<DType> to every node from root
|
||||
* \param sendrecv_data the pointer to send/receive vector,
|
||||
* for the receiver, the vector does not need to be pre-allocated
|
||||
* \param root the process root
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \tparam DType the data type stored in the vector, has to be a simple data type
|
||||
* that can be directly transmitted by sending the sizeof(DType)
|
||||
*/
|
||||
template<typename DType>
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief broadcasts a std::string to every node from the root
|
||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||
* for the receiver, the vector does not need to be pre-allocated
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \param root the process root
|
||||
*/
|
||||
inline void Broadcast(std::string *sendrecv_data, int root);
|
||||
inline void Broadcast(std::string *sendrecv_data, int root,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief performs in-place Allreduce on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -155,13 +215,22 @@ inline void Broadcast(std::string *sendrecv_data, int root);
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType data type
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
@ -185,12 +254,20 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
|
||||
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType data type
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
std::function<void()> prepare_fun,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#endif // C++11
|
||||
/*!
|
||||
* \brief loads the latest check point
|
||||
@ -286,19 +363,35 @@ class Reducer {
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||
* \param sendrecvbuf pointer to the array of objects to be reduced
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
std::function<void()> prepare_fun,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#endif // DMLC_USE_CXX11
|
||||
|
||||
private:
|
||||
@ -329,11 +422,19 @@ class SerializeReducer {
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
@ -343,10 +444,18 @@ class SerializeReducer {
|
||||
* this includes budget limit for intermediate and final result
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap failed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
std::function<void()> prepare_fun,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#endif // DMLC_USE_CXX11
|
||||
|
||||
private:
|
||||
|
||||
@ -7,6 +7,7 @@ Author: Tianqi Chen
|
||||
import pickle
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import warnings
|
||||
import numpy as np
|
||||
@ -62,6 +63,8 @@ def _loadlib(lib='standard', lib_dll=None):
|
||||
|
||||
if os.name == 'nt':
|
||||
dll_name += '.dll'
|
||||
elif platform.system() == 'Darwin':
|
||||
dll_name += '.dylib'
|
||||
else:
|
||||
dll_name += '.so'
|
||||
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
make -f test.mk model_recover_10_10k || exit -1
|
||||
make -f test.mk model_recover_10_10k_die_same || exit -1
|
||||
make -f test.mk model_recover_10_10k_die_hard || exit -1
|
||||
make -f test.mk local_recover_10_10k || exit -1
|
||||
make -f test.mk lazy_recover_10_10k_die_hard || exit -1
|
||||
make -f test.mk lazy_recover_10_10k_die_same || exit -1
|
||||
make -f test.mk ringallreduce_10_10k || exit -1
|
||||
make -f test.mk pylocal_recover_10_10k || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 model_recover_10_10k || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 model_recover_10_10k_die_same || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 model_recover_10_10k_die_hard || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 local_recover_10_10k || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 lazy_recover_10_10k_die_hard || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 lazy_recover_10_10k_die_same || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 ringallreduce_10_10k || exit -1
|
||||
make -f test.mk RABIT_BUILD_DMLC=1 pylocal_recover_10_10k || exit -1
|
||||
|
||||
@ -10,6 +10,7 @@ if [ ${TASK} == "doc" ]; then
|
||||
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag" |grep warning) && exit -1
|
||||
fi
|
||||
|
||||
# we should depreciate Makefile based build
|
||||
if [ ${TASK} == "build" ]; then
|
||||
make all RABIT_BUILD_DMLC=1 || exit -1
|
||||
fi
|
||||
@ -19,16 +20,13 @@ if [ ${TASK} == "mpi-build" ]; then
|
||||
cd test
|
||||
make mpi RABIT_BUILD_DMLC=1 && make speed_test.mpi RABIT_BUILD_DMLC=1 || exit -1
|
||||
fi
|
||||
|
||||
if [ ${TASK} == "test" ]; then
|
||||
cd test
|
||||
make all RABIT_BUILD_DMLC=1 || exit -1
|
||||
../scripts/travis_runtest.sh || exit -1
|
||||
fi
|
||||
|
||||
if [ ${TASK} == "cmake-build" ]; then
|
||||
#
|
||||
if [ ${TASK} == "cmake-test" ]; then
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DRABIT_BUILD_DMLC=ON
|
||||
make all || exit -1
|
||||
cmake -DRABIT_BUILD_TESTS=ON -DRABIT_BUILD_DMLC=ON ..
|
||||
make install || exit -1
|
||||
cd ../test
|
||||
../scripts/travis_runtest.sh || exit -1
|
||||
rm -rf ../build
|
||||
fi
|
||||
@ -46,6 +46,8 @@ AllreduceBase::AllreduceBase(void) {
|
||||
env_vars.push_back("rabit_reduce_ring_mincount");
|
||||
env_vars.push_back("rabit_tracker_uri");
|
||||
env_vars.push_back("rabit_tracker_port");
|
||||
env_vars.push_back("rabit_bootstrap_cache");
|
||||
env_vars.push_back("rabit_debug");
|
||||
// also include dmlc support direct variables
|
||||
env_vars.push_back("DMLC_TASK_ID");
|
||||
env_vars.push_back("DMLC_ROLE");
|
||||
@ -114,6 +116,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
", quit this program by exit 0\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// clear the setting before start reconnection
|
||||
this->rank = -1;
|
||||
//---------------------
|
||||
@ -147,6 +150,7 @@ bool AllreduceBase::Shutdown(void) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
@ -156,6 +160,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
tracker.SendStr(msg);
|
||||
tracker.Close();
|
||||
}
|
||||
|
||||
// util to parse data with unit suffix
|
||||
inline size_t ParseUnit(const char *name, const char *val) {
|
||||
char unit;
|
||||
@ -211,6 +216,12 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR");
|
||||
}
|
||||
}
|
||||
if (!strcmp(name, "rabit_bootstrap_cache")) {
|
||||
rabit_bootstrap_cache = atoi(val);
|
||||
}
|
||||
if (!strcmp(name, "rabit_debug")) {
|
||||
rabit_debug = atoi(val);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
@ -283,6 +294,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
Assert(rank == -1 || newrank == rank,
|
||||
"must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
|
||||
// tracker got overwhelemed and not able to assign correct rank
|
||||
if (rank == -1) exit(-1);
|
||||
|
||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
|
||||
@ -54,6 +54,7 @@ class AllreduceBase : public IEngine {
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
return rank;
|
||||
@ -82,13 +83,21 @@ class AllreduceBase : public IEngine {
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL) {
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_,
|
||||
@ -100,8 +109,14 @@ class AllreduceBase : public IEngine {
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param is_bootstrap if this broadcast is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
bool is_bootstrap = false, const char* _file = _FILE,
|
||||
const int _line = _LINE, const char* _caller = _CALLER) {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||
"Broadcast failed");
|
||||
@ -525,6 +540,10 @@ class AllreduceBase : public IEngine {
|
||||
utils::TCPSocket sock_listen;
|
||||
// backdoor port
|
||||
int port = 0;
|
||||
// enable bootstrap cache 0 false 1 true
|
||||
int rabit_bootstrap_cache = 0;
|
||||
// enable detailed logging
|
||||
int rabit_debug = 0;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include "../include/rabit/internal/io.h"
|
||||
#include "../include/rabit/internal/timer.h"
|
||||
#include "../include/rabit/internal/utils.h"
|
||||
#include "../include/rabit/internal/engine.h"
|
||||
#include "../include/rabit/internal/rabit-inl.h"
|
||||
@ -23,6 +24,7 @@ AllreduceRobust::AllreduceRobust(void) {
|
||||
num_global_replica = 5;
|
||||
default_local_replica = 2;
|
||||
seq_counter = 0;
|
||||
cur_cache_seq = 0;
|
||||
local_chkpt_version = 0;
|
||||
result_buffer_round = 1;
|
||||
global_lazycheck = NULL;
|
||||
@ -33,6 +35,9 @@ AllreduceRobust::AllreduceRobust(void) {
|
||||
}
|
||||
bool AllreduceRobust::Init(int argc, char* argv[]) {
|
||||
if (AllreduceBase::Init(argc, argv)) {
|
||||
// chenqin: alert user opted in experimental feature.
|
||||
if (rabit_bootstrap_cache) utils::HandleLogInfo(
|
||||
"[EXPERIMENTAL] rabit bootstrap cache has been enabled\n");
|
||||
if (num_global_replica == 0) {
|
||||
result_buffer_round = -1;
|
||||
} else {
|
||||
@ -48,19 +53,18 @@ bool AllreduceRobust::Shutdown(void) {
|
||||
try {
|
||||
// need to sync the exec before we shutdown, do a pesudo check point
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check point must return true");
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp,
|
||||
cur_cache_seq), "Shutdown: check point must return true");
|
||||
// reset result buffer
|
||||
resbuf.Clear();
|
||||
seq_counter = 0;
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
cachebuf.Clear(); cur_cache_seq = 0;
|
||||
lookupbuf.Clear();
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check ack must return true");
|
||||
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
||||
ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true");
|
||||
#if defined (__APPLE__)
|
||||
sleep(1);
|
||||
#endif
|
||||
|
||||
return AllreduceBase::Shutdown();
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "%s\n", e.what());
|
||||
@ -79,6 +83,69 @@ void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||
num_local_replica = atoi(val);
|
||||
}
|
||||
}
|
||||
|
||||
int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
|
||||
const size_t type_nbytes, const size_t count) {
|
||||
int index = -1;
|
||||
for (int i = 0 ; i < cur_cache_seq; i++) {
|
||||
size_t nsize = 0;
|
||||
void* name = lookupbuf.Query(i, &nsize);
|
||||
if (nsize == key.length() + 1
|
||||
&& strcmp(static_cast<const char*>(name), key.c_str()) == 0) {
|
||||
index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
utils::Assert(index == -1, "immutable cache key already exists");
|
||||
utils::Assert(type_nbytes*count > 0, "can't set empty cache");
|
||||
void* temp = cachebuf.AllocTemp(type_nbytes, count);
|
||||
cachebuf.PushTemp(cur_cache_seq, type_nbytes, count);
|
||||
std::memcpy(temp, buf, type_nbytes*count);
|
||||
|
||||
std::string k(key);
|
||||
void* name = lookupbuf.AllocTemp(strlen(k.c_str()) + 1, 1);
|
||||
lookupbuf.PushTemp(cur_cache_seq, strlen(k.c_str()) + 1, 1);
|
||||
std::memcpy(name, key.c_str(), strlen(k.c_str()) + 1);
|
||||
cur_cache_seq += 1;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||
const size_t type_nbytes, const size_t count, const bool byref) {
|
||||
// as requester sync with rest of nodes on latest cache content
|
||||
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
|
||||
seq_counter, cur_cache_seq)) return -1;
|
||||
|
||||
int index = -1;
|
||||
for (int i = 0 ; i < cur_cache_seq; i++) {
|
||||
size_t nsize = 0;
|
||||
void* name = lookupbuf.Query(i, &nsize);
|
||||
if (nsize == strlen(key.c_str()) + 1
|
||||
&& strcmp(reinterpret_cast<char*>(name), key.c_str()) == 0) {
|
||||
index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// cache doesn't exists
|
||||
if (index == -1) return -1;
|
||||
|
||||
size_t siz = 0;
|
||||
void* temp = cachebuf.Query(index, &siz);
|
||||
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||
utils::Assert(siz > 0, "cache size should be greater than 0");
|
||||
|
||||
// immutable cache, save copy time by pointer manipulation
|
||||
if (byref) {
|
||||
buf = temp;
|
||||
} else {
|
||||
std::memcpy(buf, temp, type_nbytes*count);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -90,25 +157,44 @@ void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
// skip action in single node
|
||||
if (world_size == 1 || world_size == -1) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
return;
|
||||
}
|
||||
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
||||
// now we are free to remove the last result, if any
|
||||
|
||||
// genreate unique allreduce signature
|
||||
std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
|
||||
+ std::string(_caller) + "#" +std::to_string(type_nbytes) + "x" + std::to_string(count);
|
||||
|
||||
// try fetch bootstrap allreduce results from cache
|
||||
if (is_bootstrap && rabit_bootstrap_cache &&
|
||||
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count, true) != -1) return;
|
||||
|
||||
double start = utils::GetTime();
|
||||
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
|
||||
|
||||
if (resbuf.LastSeqNo() != -1 &&
|
||||
(result_buffer_round == -1 ||
|
||||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||
(result_buffer_round == -1 ||
|
||||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||
resbuf.DropLast();
|
||||
}
|
||||
|
||||
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
void *temp = resbuf.AllocTemp(type_nbytes, count);
|
||||
while (true) {
|
||||
@ -119,23 +205,51 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||
if (CheckAndRecover(TryAllreduce(temp, type_nbytes, count, reducer))) {
|
||||
std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break;
|
||||
} else {
|
||||
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
||||
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
|
||||
}
|
||||
}
|
||||
}
|
||||
resbuf.PushTemp(seq_counter, type_nbytes, count);
|
||||
seq_counter += 1;
|
||||
double delta = utils::GetTime() - start;
|
||||
// log allreduce latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo("[%d] allreduce (%s) finished version %d, seq %d, take %f seconds\n",
|
||||
rank, key.c_str(), version_number, seq_counter, delta);
|
||||
}
|
||||
|
||||
// if bootstrap allreduce, store and fetch through cache
|
||||
if (!is_bootstrap || !rabit_bootstrap_cache) {
|
||||
resbuf.PushTemp(seq_counter, type_nbytes, count);
|
||||
seq_counter += 1;
|
||||
} else {
|
||||
SetBootstrapCache(key, sendrecvbuf_, type_nbytes, count);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
// skip action in single node
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||
// genreate unique cache signature
|
||||
std::string key = std::string(_file) + "::" + std::to_string(_line) + "::"
|
||||
+ std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
|
||||
// try fetch bootstrap allreduce results from cache
|
||||
if (is_bootstrap && rabit_bootstrap_cache &&
|
||||
GetBootstrapCache(key, sendrecvbuf_, total_size, 1, true) != -1) return;
|
||||
|
||||
double start = utils::GetTime();
|
||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
|
||||
// now we are free to remove the last result, if any
|
||||
if (resbuf.LastSeqNo() != -1 &&
|
||||
(result_buffer_round == -1 ||
|
||||
@ -150,12 +264,25 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) {
|
||||
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
||||
} else {
|
||||
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
|
||||
}
|
||||
}
|
||||
}
|
||||
resbuf.PushTemp(seq_counter, 1, total_size);
|
||||
seq_counter += 1;
|
||||
|
||||
double delta = utils::GetTime() - start;
|
||||
// log broadcast latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo(
|
||||
"[%d] broadcast (%s) root %d finished version %d,seq %d, take %f seconds\n",
|
||||
rank, key.c_str(), root, version_number, seq_counter, delta);
|
||||
}
|
||||
// if bootstrap broadcast, store and fetch through cache
|
||||
if (!is_bootstrap || !rabit_bootstrap_cache) {
|
||||
resbuf.PushTemp(seq_counter, 1, total_size);
|
||||
seq_counter += 1;
|
||||
} else {
|
||||
SetBootstrapCache(key, sendrecvbuf_, total_size, 1);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
@ -188,8 +315,9 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
|
||||
utils::Check(local_model == NULL,
|
||||
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
// check if we succesful
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
||||
double start = utils::GetTime();
|
||||
// check if we succeed
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp, cur_cache_seq)) {
|
||||
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
|
||||
if (local_model != NULL) {
|
||||
if (nlocal == num_local_replica + 1) {
|
||||
@ -215,10 +343,26 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
|
||||
"local model inconsistent, nlocal=%d", nlocal);
|
||||
}
|
||||
// run another phase of check ack, if recovered from data
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"check ack must return true");
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
||||
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
|
||||
|
||||
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) {
|
||||
utils::Printf("no need to load cache\n");
|
||||
}
|
||||
double delta = utils::GetTime() - start;
|
||||
|
||||
// log broadcast latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo("[%d] loadcheckpoint size %ld finished version %d, "
|
||||
"seq %d, take %f seconds\n",
|
||||
rank, global_checkpoint.length(),
|
||||
version_number, seq_counter, delta);
|
||||
}
|
||||
return version_number;
|
||||
} else {
|
||||
// log job fresh start
|
||||
if (rabit_debug) utils::HandleLogInfo("[%d] loadcheckpoint reset\n", rank);
|
||||
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0; version_number = 0;
|
||||
// nothing loaded, a fresh start, everyone init model
|
||||
@ -269,6 +413,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
||||
if (world_size == 1) {
|
||||
version_number += 1; return;
|
||||
}
|
||||
double start = utils::GetTime();
|
||||
this->LocalModelCheck(local_model != NULL);
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL,
|
||||
@ -297,7 +442,8 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
||||
local_chkpt_version = !local_chkpt_version;
|
||||
}
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
|
||||
ActionSummary::kSpecialOp, cur_cache_seq),
|
||||
"check point must return true");
|
||||
// this is the critical region where we will change all the stored models
|
||||
// increase version number
|
||||
@ -306,18 +452,32 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
||||
if (lazy_checkpt) {
|
||||
global_lazycheck = global_model;
|
||||
} else {
|
||||
printf("[%d] save global checkpoint #%d \n", this->rank, version_number);
|
||||
global_checkpoint.resize(0);
|
||||
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||
fs.Write(&version_number, sizeof(version_number));
|
||||
global_model->Save(&fs);
|
||||
global_lazycheck = NULL;
|
||||
}
|
||||
// reset result buffer
|
||||
double delta = utils::GetTime() - start;
|
||||
// log checkpoint latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo(
|
||||
"[%d] checkpoint finished version %d,seq %d, take %f seconds\n",
|
||||
rank, version_number, seq_counter, delta);
|
||||
}
|
||||
start = utils::GetTime();
|
||||
// reset result buffer, mark boostrap phase complete
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"check ack must return true");
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
||||
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
|
||||
|
||||
delta = utils::GetTime() - start;
|
||||
// log checkpoint ack latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n",
|
||||
rank, version_number, delta);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
@ -557,6 +717,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
||||
{
|
||||
// get the shortest distance to the request point
|
||||
std::vector<std::pair<int, size_t> > dist_in, dist_out;
|
||||
|
||||
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
|
||||
&dist_in, &dist_out, ShortestDist);
|
||||
if (succ != kSuccess) return succ;
|
||||
@ -723,6 +884,58 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief try to fetch allreduce/broadcast results from rest of nodes
|
||||
* as collaberative function called by all nodes, only requester node
|
||||
* will pass seqno to rest of nodes and reconstruct/backfill sendrecvbuf_
|
||||
* of specific seqno from other nodes.
|
||||
*/
|
||||
AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester,
|
||||
const int min_seq, const int max_seq) {
|
||||
// clear requester and rebuild from those with most cache entries
|
||||
if (requester) {
|
||||
utils::Assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries");
|
||||
cachebuf.Clear();
|
||||
lookupbuf.Clear();
|
||||
cur_cache_seq = 0;
|
||||
}
|
||||
RecoverType role = requester ? kRequestData : kHaveData;
|
||||
size_t size = 1;
|
||||
int recv_link;
|
||||
std::vector<bool> req_in;
|
||||
ReturnType ret = TryDecideRouting(role, &size, &recv_link, &req_in);
|
||||
if (ret != kSuccess) return ret;
|
||||
// only recover missing cache entries in requester
|
||||
// as tryrecoverdata is collective call, need to go through entire cache
|
||||
// and only work on those missing
|
||||
for (int i = 0; i < max_seq; i++) {
|
||||
// restore lookup map
|
||||
size_t cache_size = 0;
|
||||
void* key = lookupbuf.Query(i, &cache_size);
|
||||
ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in);
|
||||
if (ret != kSuccess) return ret;
|
||||
if (requester) {
|
||||
key = lookupbuf.AllocTemp(cache_size, 1);
|
||||
lookupbuf.PushTemp(i, cache_size, 1);
|
||||
}
|
||||
ret = TryRecoverData(role, key, cache_size, recv_link, req_in);
|
||||
if (ret != kSuccess) return ret;
|
||||
// restore cache content
|
||||
cache_size = 0;
|
||||
void* buf = cachebuf.Query(i, &cache_size);
|
||||
ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in);
|
||||
if (requester) {
|
||||
buf = cachebuf.AllocTemp(cache_size, 1);
|
||||
cachebuf.PushTemp(i, cache_size, 1);
|
||||
cur_cache_seq +=1;
|
||||
}
|
||||
ret = TryRecoverData(role, buf, cache_size, recv_link, req_in);
|
||||
if (ret != kSuccess) return ret;
|
||||
}
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
@ -748,9 +961,6 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
|
||||
&local_chkpt[local_chkpt_version]);
|
||||
if (succ != kSuccess) return succ;
|
||||
|
||||
printf("[%d] recovered from local checkpoint version %d \n", this->rank, local_chkpt_version);
|
||||
|
||||
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
|
||||
// check if everyone is OK
|
||||
unsigned state = 0;
|
||||
@ -817,6 +1027,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
||||
"TryGetResult::Checkpoint");
|
||||
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
||||
}
|
||||
|
||||
// handles normal data recovery
|
||||
RecoverType role;
|
||||
if (!requester) {
|
||||
@ -857,18 +1068,28 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
||||
* result by recovering procedure, the action is complete, no further action is needed
|
||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||
*/
|
||||
bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
if (flag != 0) {
|
||||
bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
||||
int cache_seqno, const char* caller) {
|
||||
// kLoadBootstrapCache should be treated similar as allreduce
|
||||
// when loadcheck/check/checkack runs in other nodes
|
||||
if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) {
|
||||
utils::Assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
|
||||
}
|
||||
// request
|
||||
ActionSummary req(flag, seqno);
|
||||
|
||||
std::string msg = std::string(caller) + " pass negative seqno "
|
||||
+ std::to_string(seqno) + " flag " + std::to_string(flag)
|
||||
+ " version " + std::to_string(version_number);
|
||||
utils::Assert(seqno >=0, msg.c_str());
|
||||
|
||||
ActionSummary req(flag, flag, seqno, cache_seqno);
|
||||
|
||||
while (true) {
|
||||
this->ReportStatus();
|
||||
// action
|
||||
// copy to action and send to allreduce with other nodes
|
||||
ActionSummary act = req;
|
||||
// get the reduced action
|
||||
if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
|
||||
|
||||
if (act.check_ack()) {
|
||||
if (act.check_point()) {
|
||||
// if we also have check_point, do check point first
|
||||
@ -891,9 +1112,49 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
} else {
|
||||
if (act.check_point()) {
|
||||
if (act.diff_seq()) {
|
||||
utils::Assert(act.min_seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
||||
bool requester = req.min_seqno() == act.min_seqno();
|
||||
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
||||
utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
||||
// print checkpoint consensus flag if user turn on debug
|
||||
if (rabit_debug) {
|
||||
req.print_flags(rank, "checkpoint req");
|
||||
act.print_flags(rank, "checkpoint act");
|
||||
}
|
||||
/*
|
||||
* Chen Qin
|
||||
* at least one hit checkpoint_ code & at least one not hitting
|
||||
* compare with version_number of req.check_point() set true with rest
|
||||
* expect to be equal, means rest fall behind in sequence
|
||||
* use resbuf resbuf to recover
|
||||
* worker-0 worker-1
|
||||
* checkpoint(n-1) checkpoint(n-1)
|
||||
* allreduce allreduce (requester) |
|
||||
* broadcast V
|
||||
* checkpoint(n req)
|
||||
* after catch up to checkpoint n, diff_seq will be false
|
||||
* */
|
||||
// assume requester is falling behind
|
||||
bool requester = req.seqno() == act.seqno();
|
||||
// if not load cache
|
||||
if (!act.load_cache()) {
|
||||
if (act.seqno() > 0) {
|
||||
if (!requester) {
|
||||
utils::Assert(req.check_point(), "checkpoint node should be KHaveData role");
|
||||
buf = resbuf.Query(act.seqno(), &size);
|
||||
utils::Assert(buf != NULL, "buf should have data from resbuf");
|
||||
utils::Assert(size > 0, "buf size should be greater than 0");
|
||||
}
|
||||
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
|
||||
}
|
||||
} else {
|
||||
// cache seq no should be smaller than kSpecialOp
|
||||
utils::Assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp,
|
||||
"checkpoint with kSpecialOp");
|
||||
int max_cache_seq = cur_cache_seq;
|
||||
if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1,
|
||||
op::Reducer<op::Max, unsigned>) != kSuccess) continue;
|
||||
|
||||
if (TryRestoreCache(req.load_cache(), act.seqno(), max_cache_seq)
|
||||
!= kSuccess) continue;
|
||||
}
|
||||
if (requester) return true;
|
||||
} else {
|
||||
// no difference in seq no, means we are free to check point
|
||||
@ -909,11 +1170,43 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
// if requested load check, then misson complete
|
||||
if (req.load_check()) return true;
|
||||
} else {
|
||||
// run all nodes in a isolated cache restore logic
|
||||
if (act.load_cache()) {
|
||||
// print checkpoint consensus flag if user turn on debug
|
||||
if (rabit_debug) {
|
||||
req.print_flags(rank, "loadcache req");
|
||||
act.print_flags(rank, "loadcache act");
|
||||
}
|
||||
// load cache should not running in parralel with other states
|
||||
utils::Assert(!act.load_check(),
|
||||
"load cache state expect no nodes doing load checkpoint");
|
||||
utils::Assert(!act.check_point() ,
|
||||
"load cache state expect no nodes doing checkpoint");
|
||||
utils::Assert(!act.check_ack(),
|
||||
"load cache state expect no nodes doing checkpoint ack");
|
||||
|
||||
// if all nodes are requester in load cache, skip
|
||||
if (act.load_cache(SeqType::kCache)) return false;
|
||||
|
||||
// only restore when at least one pair of max_seq are different
|
||||
if (act.diff_seq(SeqType::kCache)) {
|
||||
// if restore cache failed, retry from what's left
|
||||
if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
|
||||
!= kSuccess) continue;
|
||||
}
|
||||
// if requested load cache, then mission complete
|
||||
if (req.load_cache()) return true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// assert no req with load cache set goes into seq catch up
|
||||
utils::Assert(!req.load_cache(), "load cache not interacte with rest states");
|
||||
|
||||
// no special flags, no checkpoint, check ack, load_check
|
||||
utils::Assert(act.min_seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
||||
utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
||||
if (act.diff_seq()) {
|
||||
bool requester = req.min_seqno() == act.min_seqno();
|
||||
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
||||
bool requester = req.seqno() == act.seqno();
|
||||
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
|
||||
if (requester) return true;
|
||||
} else {
|
||||
// all the request is same,
|
||||
|
||||
@ -33,6 +33,23 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
/*!
|
||||
* \brief perform immutable local bootstrap cache insertion
|
||||
* \param key unique cache key
|
||||
* \param buf buffer of allreduce/robust payload to copy
|
||||
* \param buflen total number of bytes
|
||||
* \return -1 if no recovery cache fetched otherwise 0
|
||||
*/
|
||||
int SetBootstrapCache(const std::string &key, const void *buf,
|
||||
const size_t type_nbytes, const size_t count);
|
||||
/*!
|
||||
* \brief perform bootstrap cache lookup if nodes in fault recovery
|
||||
* \param key unique cache key
|
||||
* \param buf buffer for recv allreduce/robust payload
|
||||
* \param buflen total number of bytes
|
||||
*/
|
||||
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
||||
const size_t count, const bool byref = false);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -44,20 +61,37 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param is_bootstrap if this allreduce is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
void *prepare_arg = NULL,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param is_bootstrap if this broadcast is needed to bootstrap filed node
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
bool is_bootstrap = false,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
@ -155,6 +189,13 @@ class AllreduceRobust : public AllreduceBase {
|
||||
/*! \brief current node only helps to pass data around */
|
||||
kPassData = 2
|
||||
};
|
||||
|
||||
enum SeqType {
|
||||
/*! \brief apply to rabit seq code */
|
||||
kSeq = 0,
|
||||
/*! \brief apply to rabit cache seq code */
|
||||
kCache = 1
|
||||
};
|
||||
/*!
|
||||
* \brief summary of actions proposed in all nodes
|
||||
* this data structure is used to make consensus decision
|
||||
@ -162,11 +203,11 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
struct ActionSummary {
|
||||
// maximumly allowed sequence id
|
||||
static const int kSpecialOp = (1 << 26);
|
||||
static const u_int32_t kSpecialOp = (1 << 26);
|
||||
// special sequence number for local state checkpoint
|
||||
static const int kLocalCheckPoint = (1 << 26) - 2;
|
||||
static const u_int32_t kLocalCheckPoint = (1 << 26) - 2;
|
||||
// special sequnce number for local state checkpoint ack signal
|
||||
static const int kLocalCheckAck = (1 << 26) - 1;
|
||||
static const u_int32_t kLocalCheckAck = (1 << 26) - 1;
|
||||
//---------------------------------------------
|
||||
// The following are bit mask of flag used in
|
||||
//----------------------------------------------
|
||||
@ -181,35 +222,59 @@ class AllreduceRobust : public AllreduceBase {
|
||||
// this means we want to do recover execution of the lower sequence
|
||||
// action instead of normal execution
|
||||
static const int kDiffSeq = 8;
|
||||
// there are nodes request load cache
|
||||
static const int kLoadBootstrapCache = 16;
|
||||
// constructor
|
||||
ActionSummary(void) {}
|
||||
// constructor of action
|
||||
explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
|
||||
seqcode = (minseqno << 4) | flag;
|
||||
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
||||
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
|
||||
seqcode = (minseqno << 5) | seqno_flag;
|
||||
maxseqcode = (maxseqno << 5) | cache_flag;
|
||||
}
|
||||
// minimum number of all operations
|
||||
inline int min_seqno(void) const {
|
||||
return seqcode >> 4;
|
||||
// minimum number of all operations by default
|
||||
// maximum number of all cache operations otherwise
|
||||
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return code >> 5;
|
||||
}
|
||||
// whether the operation set contains a load_check
|
||||
inline bool load_check(void) const {
|
||||
return (seqcode & kLoadCheck) != 0;
|
||||
inline bool load_check(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return (code & kLoadCheck) != 0;
|
||||
}
|
||||
// whether the operation set contains a load_cache
|
||||
inline bool load_cache(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return (code & kLoadBootstrapCache) != 0;
|
||||
}
|
||||
// whether the operation set contains a check point
|
||||
inline bool check_point(void) const {
|
||||
return (seqcode & kCheckPoint) != 0;
|
||||
inline bool check_point(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return (code & kCheckPoint) != 0;
|
||||
}
|
||||
// whether the operation set contains a check ack
|
||||
inline bool check_ack(void) const {
|
||||
return (seqcode & kCheckAck) != 0;
|
||||
inline bool check_ack(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return (code & kCheckAck) != 0;
|
||||
}
|
||||
// whether the operation set contains different sequence number
|
||||
inline bool diff_seq(void) const {
|
||||
return (seqcode & kDiffSeq) != 0;
|
||||
inline bool diff_seq(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return (code & kDiffSeq) != 0;
|
||||
}
|
||||
// returns the operation flag of the result
|
||||
inline int flag(void) const {
|
||||
return seqcode & 15;
|
||||
inline int flag(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return code & 31;
|
||||
}
|
||||
// print flags in user friendly way
|
||||
inline void print_flags(int rank, std::string prefix ) {
|
||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|%d|\n",
|
||||
rank, prefix.c_str(),
|
||||
seqno(), check_point(), check_ack(), load_cache(),
|
||||
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache),
|
||||
diff_seq(SeqType::kCache));
|
||||
}
|
||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_,
|
||||
@ -217,24 +282,31 @@ class AllreduceRobust : public AllreduceBase {
|
||||
const ActionSummary *src = (const ActionSummary*)src_;
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int src_seqno = src[i].min_seqno();
|
||||
int dst_seqno = dst[i].min_seqno();
|
||||
int flag = src[i].flag() | dst[i].flag();
|
||||
if (src_seqno == dst_seqno) {
|
||||
dst[i] = ActionSummary(flag, src_seqno);
|
||||
} else {
|
||||
dst[i] = ActionSummary(flag | kDiffSeq,
|
||||
std::min(src_seqno, dst_seqno));
|
||||
}
|
||||
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
|
||||
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
|
||||
dst[i].seqno(SeqType::kCache));
|
||||
int action_flag = src[i].flag() | dst[i].flag();
|
||||
// if any node is not requester set to 0 otherwise 1
|
||||
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
|
||||
// if seqno is different in src and destination
|
||||
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
|
||||
// if cache seqno is different in src and destination
|
||||
int cache_diff_flag =
|
||||
src[i].seqno(SeqType::kCache) != dst[i].seqno(SeqType::kCache) ? kDiffSeq : 0;
|
||||
// apply or to both seq diff flag as well as cache seq diff flag
|
||||
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
||||
role_flag | cache_diff_flag, min_seqno, max_seqno);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// internel sequence code
|
||||
int seqcode;
|
||||
// internel sequence code min of rabit seqno
|
||||
u_int32_t seqcode;
|
||||
// internal sequence code max of cache seqno
|
||||
u_int32_t maxseqcode;
|
||||
};
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls */
|
||||
class ResultBuffer {
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
||||
class ResultBuffer{
|
||||
public:
|
||||
// constructor
|
||||
ResultBuffer(void) {
|
||||
@ -251,6 +323,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
utils::Assert(nhop != 0, "cannot allocate 0 size memory");
|
||||
// allocate addational nhop buffer size
|
||||
data_.resize(rptr_.back() + nhop);
|
||||
return BeginPtr(data_) + rptr_.back();
|
||||
}
|
||||
@ -362,7 +435,13 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||
*/
|
||||
bool RecoverExec(void *buf, size_t size, int flag,
|
||||
int seqno = ActionSummary::kSpecialOp);
|
||||
int seqno = ActionSummary::kSpecialOp,
|
||||
int cacheseqno = ActionSummary::kSpecialOp,
|
||||
#ifdef __linux__
|
||||
const char* caller = __builtin_FUNCTION());
|
||||
#else
|
||||
const char* caller = "N/A");
|
||||
#endif
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
@ -375,6 +454,19 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryLoadCheckPoint(bool requester);
|
||||
|
||||
/*!
|
||||
* \brief try to load cache
|
||||
*
|
||||
* This is a collaborative function called by all nodes
|
||||
* only the nodes with requester set to true really needs to load the check point
|
||||
* other nodes acts as collaborative roles to complete this request
|
||||
* \param requester whether current node is the requester
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryRestoreCache(bool requester, const int min_seq = ActionSummary::kSpecialOp,
|
||||
const int max_seq = ActionSummary::kSpecialOp);
|
||||
/*!
|
||||
* \brief try to get the result of operation specified by seqno
|
||||
*
|
||||
@ -519,6 +611,12 @@ o * the input state must exactly one saved state(local state of current node)
|
||||
int result_buffer_round;
|
||||
// result buffer of all reduce
|
||||
ResultBuffer resbuf;
|
||||
// current cached allreduce/braodcast sequence number
|
||||
int cur_cache_seq;
|
||||
// result buffer of cached all reduce
|
||||
ResultBuffer cachebuf;
|
||||
// key of each cache entry
|
||||
ResultBuffer lookupbuf;
|
||||
// last check point global model
|
||||
std::string global_checkpoint;
|
||||
// lazy checkpoint of global model
|
||||
|
||||
@ -92,9 +92,13 @@ void Allreduce_(void *sendrecvbuf,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
red, prepare_fun, prepare_arg);
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun,
|
||||
prepare_arg, is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
@ -116,10 +120,15 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
redfunc_, prepare_fun, prepare_arg);
|
||||
redfunc_, prepare_fun, prepare_arg,
|
||||
is_bootstrap, _file, _line, _caller);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -30,11 +30,17 @@ class EmptyEngine : public IEngine {
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
bool is_bootstrap, const char* _file,
|
||||
const int _line, const char* _caller) {
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
@ -102,7 +108,11 @@ void Allreduce_(void *sendrecvbuf,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
}
|
||||
|
||||
@ -118,7 +128,11 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
|
||||
@ -32,11 +32,17 @@ class MPIEngine : public IEngine {
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
utils::Error("MPIEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
bool is_bootstrap, const char* _file, const int _line,
|
||||
const char* _caller) {
|
||||
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
@ -153,7 +159,11 @@ void Allreduce_(void *sendrecvbuf,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
|
||||
count, GetType(dtype), GetOp(op));
|
||||
@ -201,7 +211,11 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
bool is_bootstrap,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||
|
||||
@ -1,26 +1,38 @@
|
||||
RABIT_BUILD_DMLC = 0
|
||||
|
||||
ifeq ($(RABIT_BUILD_DMLC),1)
|
||||
DMLC=../dmlc-core
|
||||
else
|
||||
DMLC=../../dmlc-core
|
||||
endif
|
||||
|
||||
MPICXX=../mpich/bin/mpicxx
|
||||
export LDFLAGS= -L../lib -pthread -lm
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -I ../dmlc-core/include -std=c++11
|
||||
export CFLAGS = -Wall -O3 -Wno-unknown-pragmas
|
||||
|
||||
OS := $(shell uname)
|
||||
export CC = gcc
|
||||
export CXX = g++
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
ifndef CC
|
||||
export CC = $(if $(shell which clang), clang, gcc)
|
||||
endif
|
||||
ifndef CXX
|
||||
export CXX = $(if $(shell which clang++), clang++, g++)
|
||||
endif
|
||||
|
||||
#----------------------------
|
||||
# Settings for power and arm arch
|
||||
#----------------------------
|
||||
ARCH := $(shell uname -a)
|
||||
ifneq (,$(filter $(ARCH), armv6l armv7l powerpc64le ppc64le aarch64))
|
||||
CFLAGS += -march=native
|
||||
else
|
||||
ifndef CC
|
||||
export CC = gcc
|
||||
endif
|
||||
ifndef CXX
|
||||
export CXX = g++
|
||||
endif
|
||||
LDFLAGS += -lrt
|
||||
CFLAGS += -msse2
|
||||
endif
|
||||
|
||||
ifndef WITH_FPIC
|
||||
WITH_FPIC = 1
|
||||
endif
|
||||
ifeq ($(WITH_FPIC), 1)
|
||||
CFLAGS += -fPIC
|
||||
endif
|
||||
|
||||
CFLAGS += -I../include -I $(DMLC)/include -std=c++11
|
||||
|
||||
# specify tensor path
|
||||
BIN = speed_test model_recover local_recover lazy_recover
|
||||
OBJ = $(RABIT_OBJ) speed_test.o model_recover.o local_recover.o lazy_recover.o
|
||||
|
||||
@ -94,6 +94,11 @@ int main(int argc, char *argv[]) {
|
||||
int rank = rabit::GetRank();
|
||||
int nproc = rabit::GetWorldSize();
|
||||
std::string name = rabit::GetProcessorName();
|
||||
|
||||
int max_rank = rank;
|
||||
rabit::Allreduce<op::Max>(&max_rank, sizeof(int), NULL, NULL, true);
|
||||
utils::Check(max_rank == nproc - 1, "max rank is world size-1");
|
||||
|
||||
Model model;
|
||||
srand(0);
|
||||
int ntrial = 0;
|
||||
@ -115,6 +120,7 @@ int main(int argc, char *argv[]) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
|
||||
TestSum(&model, ntrial, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model);
|
||||
|
||||
24
test/test.mk
24
test/test.mk
@ -1,3 +1,11 @@
|
||||
RABIT_BUILD_DMLC = 0
|
||||
|
||||
ifeq ($(RABIT_BUILD_DMLC),1)
|
||||
DMLC=../dmlc-core
|
||||
else
|
||||
DMLC=../../dmlc-core
|
||||
endif
|
||||
|
||||
# this is a makefile used to show testcases of rabit
|
||||
.PHONY: all
|
||||
|
||||
@ -5,25 +13,25 @@ all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_di
|
||||
|
||||
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
||||
model_recover_10_10k:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=-1 rabit_debug=1 rabit_reduce_ring_mincount=1
|
||||
|
||||
model_recover_10_10k_die_same:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 rabit_bootstrap_cache=1
|
||||
|
||||
model_recover_10_10k_die_hard:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 rabit_bootstrap_cache=1
|
||||
|
||||
local_recover_10_10k:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1
|
||||
|
||||
pylocal_recover_10_10k:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1
|
||||
|
||||
lazy_recover_10_10k_die_hard:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
||||
|
||||
lazy_recover_10_10k_die_same:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0
|
||||
|
||||
ringallreduce_10_10k:
|
||||
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 rabit_reduce_ring_mincount=10
|
||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 rabit_reduce_ring_mincount=10
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user