Cleanup RABIT. (#6290)

* Remove recovery and MPI speed tests.
* Remove readme.
* Remove Python binding.
* Add checks in C API.
This commit is contained in:
Jiaming Yuan
2020-10-27 08:48:22 +08:00
committed by GitHub
parent 8e0f5a6fc7
commit b180223d18
40 changed files with 113 additions and 1875 deletions

View File

@@ -41,7 +41,7 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]);
* call this function after you finished all jobs.
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL bool RabitFinalize(void);
RABIT_DLL int RabitFinalize(void);
/*!
* \brief get rank of previous process in ring topology
@@ -91,8 +91,7 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
* \param size the size of the data
* \param root the root of process
*/
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root);
RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root);
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
@@ -110,12 +109,9 @@ RABIT_DLL void RabitBroadcast(void *sendrecv_data,
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
RABIT_DLL void RabitAllgather(void *sendrecvbuf,
size_t total_size,
size_t beginIndex,
size_t size_node_slice,
size_t size_prev_slice,
int enum_dtype);
RABIT_DLL int RabitAllgather(void *sendrecvbuf, size_t total_size,
size_t beginIndex, size_t size_node_slice,
size_t size_prev_slice, int enum_dtype);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
@@ -133,14 +129,11 @@ RABIT_DLL void RabitAllgather(void *sendrecvbuf,
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
RABIT_DLL void RabitAllreduce(void *sendrecvbuf,
size_t count,
int enum_dtype,
int enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg);
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
int enum_op, void (*prepare_fun)(void *arg),
void *prepare_arg);
/*!
* \brief load latest check point

View File

@@ -9,16 +9,6 @@
#include <string>
#include "rabit/serializable.h"
#if (defined(__GNUC__) && !defined(__clang__))
#define _FILE __builtin_FILE()
#define _LINE __builtin_LINE()
#define _CALLER __builtin_FUNCTION()
#else
#define _FILE "N/A"
#define _LINE -1
#define _CALLER "N/A"
#endif // (defined(__GNUC__) && !defined(__clang__))
namespace MPI { // NOLINT
/*! \brief MPI data type just to be compatible with MPI reduce function*/
class Datatype;
@@ -65,18 +55,12 @@ class IEngine {
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
size_t size_prev_slice) = 0;
/*!
* \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe
@@ -88,38 +72,20 @@ class IEngine {
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
void *prepare_arg = nullptr) = 0;
/*!
* \brief broadcasts data from root to every other node
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
/*!
* \brief explicitly re-initialize everything before calling LoadCheckPoint
* call this function when IEngine throws an exception,
* this function should only be used for test purposes
*/
virtual void InitAfterException() = 0;
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
/*!
* \brief loads the latest check point
* \param global_model pointer to the globally shared model/state
@@ -250,18 +216,12 @@ enum DataType {
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allgather(void* sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
size_t size_prev_slice);
/*!
* \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI
@@ -276,9 +236,6 @@ void Allgather(void* sendrecvbuf,
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function.
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes,
@@ -287,10 +244,7 @@ void Allreduce_(void *sendrecvbuf, // NOLINT
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void *prepare_arg = nullptr);
/*!
* \brief handle for customized reducer, used to handle customized reduce
* this class is mainly created for compatiblity issues with MPI's customized reduce
@@ -316,18 +270,13 @@ class ReduceHandle {
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allreduce(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void *prepare_arg = nullptr);
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);

View File

@@ -131,40 +131,28 @@ inline std::string GetProcessorName() {
return engine::GetEngine()->GetHost();
}
// broadcast data to all other nodes from root
inline void Broadcast(void *sendrecv_data, size_t size, int root,
const char* _file,
const int _line,
const char* _caller) {
engine::GetEngine()->Broadcast(sendrecv_data, size, root,
_file, _line, _caller);
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
engine::GetEngine()->Broadcast(sendrecv_data, size, root);
}
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
const char* _file,
const int _line,
const char* _caller) {
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
size_t size = sendrecv_data->size();
Broadcast(&size, sizeof(size), root, _file, _line, _caller);
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->size() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root,
_file, _line, _caller);
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root);
}
}
inline void Broadcast(std::string *sendrecv_data, int root,
const char* _file,
const int _line,
const char* _caller) {
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root, _file, _line, _caller);
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root,
_file, _line, _caller);
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
@@ -172,13 +160,9 @@ inline void Broadcast(std::string *sendrecv_data, int root,
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void *prepare_arg) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg,
_file, _line, _caller);
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
}
// C++11 support for lambda prepare function
@@ -188,13 +172,9 @@ inline void InvokeLambda(void *fun) {
}
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun,
const char* _file,
const int _line,
const char* _caller) {
std::function<void()> prepare_fun) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun,
_file, _line, _caller);
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun);
}
// Performs inplace Allgather
@@ -203,13 +183,10 @@ inline void Allgather(DType *sendrecvbuf,
size_t totalSize,
size_t beginIndex,
size_t sizeNodeSlice,
size_t sizePrevSlice,
const char* _file,
const int _line,
const char* _caller) {
size_t sizePrevSlice) {
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
(beginIndex + sizeNodeSlice) * sizeof(DType),
sizePrevSlice * sizeof(DType), _file, _line, _caller);
sizePrevSlice * sizeof(DType));
}
#endif // C++11
@@ -289,12 +266,9 @@ inline Reducer<DType, freduce>::Reducer() {
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void *prepare_arg) {
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun,
prepare_arg, _file, _line, _caller);
prepare_arg);
}
// function to perform reduction for SerializeReducer
template<typename DType>
@@ -342,10 +316,7 @@ template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void *prepare_arg) {
buffer_.resize(max_nbyte * count);
// setup closure
SerializeReduceClosure<DType> c;
@@ -353,34 +324,23 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
// invoke here
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
SerializeReduceClosure<DType>::Invoke, &c,
_file, _line, _caller);
SerializeReduceClosure<DType>::Invoke, &c);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Load(fs);
}
}
#if DMLC_USE_CXX11
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun,
const char* _file,
const int _line,
const char* _caller) {
this->Allreduce(sendrecvbuf, count, InvokeLambda, &prepare_fun,
_file, _line, _caller);
std::function<void()> prepare_fun) {
this->Allreduce(sendrecvbuf, count, InvokeLambda, &prepare_fun);
}
template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbytes, size_t count,
std::function<void()> prepare_fun,
const char* _file,
const int _line,
const char* _caller) {
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda, &prepare_fun,
_file, _line, _caller);
std::function<void()> prepare_fun) {
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda, &prepare_fun);
}
#endif // DMLC_USE_CXX11
} // namespace rabit
#endif // RABIT_INTERNAL_RABIT_INL_H_

View File

@@ -1,41 +0,0 @@
/*!
* Copyright (c) 2014-2019 by Contributors
* \file timer.h
* \brief This file defines the utils for timing
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_INTERNAL_TIMER_H_
#define RABIT_INTERNAL_TIMER_H_
#include <ctime>
#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
#endif // __MACH__
#include "./utils.h"
namespace rabit {
namespace utils {
/*!
* \brief return time in seconds, not cross platform, avoid to use this in most places
*/
inline double GetTime() {
#ifdef __MACH__
clock_serv_t cclock;
mach_timespec_t mts;
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
utils::Check(clock_get_time(cclock, &mts) == 0, "failed to get time");
mach_port_deallocate(mach_task_self(), cclock);
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
#else
#if defined(__unix__) || defined(__linux__)
timespec ts;
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
#else
return static_cast<double>(time(NULL));
#endif // defined(__unix__) || defined(__linux__)
#endif // __MACH__
}
} // namespace utils
} // namespace rabit
#endif // RABIT_INTERNAL_TIMER_H_

View File

@@ -69,35 +69,10 @@ inline bool StringToBool(const char* s) {
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
}
/*!
* \brief handling of Assert error, caused by inappropriate input
* \param msg error message
*/
inline void HandleAssertError(const char *msg) {
LOG(FATAL) << msg;
}
/*!
* \brief handling of Check error, caused by inappropriate input
* \param msg error message
*/
inline void HandleCheckError(const char *msg) {
LOG(FATAL) << msg;
}
inline void HandlePrint(const char *msg) {
printf("%s", msg);
}
inline void HandleLogInfo(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
fprintf(stdout, "%s", msg.c_str());
fflush(stdout);
}
/*! \brief printf, prints messages to the console */
inline void Printf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
@@ -108,15 +83,6 @@ inline void Printf(const char *fmt, ...) {
HandlePrint(msg.c_str());
}
/*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args;
va_start(args, fmt);
int ret = vsnprintf(buf, size, fmt, args);
va_end(args);
return ret;
}
/*! \brief assert a condition is true, use this to handle debug information */
inline void Assert(bool exp, const char *fmt, ...) {
if (!exp) {
@@ -125,7 +91,7 @@ inline void Assert(bool exp, const char *fmt, ...) {
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleAssertError(msg.c_str());
LOG(FATAL) << msg;
}
}
@@ -137,7 +103,7 @@ inline void Check(bool exp, const char *fmt, ...) {
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleCheckError(msg.c_str());
LOG(FATAL) << msg;
}
}
@@ -149,7 +115,7 @@ inline void Error(const char *fmt, ...) {
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleCheckError(msg.c_str());
LOG(FATAL) << msg;
}
}
} // namespace utils
@@ -176,15 +142,6 @@ inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
return &vec[0];
}
}
/*! \brief get the beginning address of a vector */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return nullptr;
} else {
return &vec[0];
}
}
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return nullptr;
return &str[0];

View File

@@ -12,36 +12,7 @@
#define RABIT_RABIT_H_ // NOLINT(*)
#include <string>
#include <vector>
// whether or not use c++11 support
#ifndef DMLC_USE_CXX11
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
#define DMLC_USE_CXX11 1
#else
#define DMLC_USE_CXX11 (__cplusplus >= 201103L)
#endif // defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
#endif // DMLC_USE_CXX11
// keeps rabit api caller signature
#ifndef RABIT_API_CALLER_SIGNATURE
#define RABIT_API_CALLER_SIGNATURE
#if (defined(__GNUC__) && !defined(__clang__))
#define _FILE __builtin_FILE()
#define _LINE __builtin_LINE()
#define _CALLER __builtin_FUNCTION()
#else
#define _FILE "N/A"
#define _LINE -1
#define _CALLER "N/A"
#endif // (defined(__GNUC__) && !defined(__clang__))
#endif // RABIT_API_CALLER_SIGNATURE
// optionally support of lambda functions in C++11, if available
#if DMLC_USE_CXX11
#include <functional>
#endif // C++11
// engine definition of rabit, defines internal implementation
// to use rabit interface, there is no need to read engine.h
// rabit.h and serializable.h are enough to use the interface
@@ -135,31 +106,19 @@ inline void TrackerPrintf(const char *fmt, ...);
* \param sendrecv_data the pointer to the send/receive buffer,
* \param size the data size
* \param root the process root
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
inline void Broadcast(void *sendrecv_data, size_t size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
inline void Broadcast(void *sendrecv_data, size_t size, int root);
/*!
* \brief broadcasts an std::vector<DType> to every node from root
* \param sendrecv_data the pointer to send/receive vector,
* for the receiver, the vector does not need to be pre-allocated
* \param root the process root
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
* \tparam DType the data type stored in the vector, has to be a simple data type
* that can be directly transmitted by sending the sizeof(DType)
*/
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
/*!
* \brief broadcasts a std::string to every node from the root
* \param sendrecv_data the pointer to the send/receive buffer,
@@ -169,10 +128,7 @@ inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
* \param _caller caller function name used to generate unique cache key
* \param root the process root
*/
inline void Broadcast(std::string *sendrecv_data, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
inline void Broadcast(std::string *sendrecv_data, int root);
/*!
* \brief performs in-place Allreduce on sendrecvbuf
* this function is NOT thread-safe
@@ -191,19 +147,13 @@ inline void Broadcast(std::string *sendrecv_data, int root,
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
* \tparam OP see namespace op, reduce operator
* \tparam DType data type
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void *prepare_arg = nullptr);
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
@@ -217,19 +167,13 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
template<typename DType>
inline void Allgather(DType *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
size_t size_prev_slice);
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
@@ -254,18 +198,12 @@ inline void Allgather(DType *sendrecvbuf_,
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
* \tparam OP see namespace op, reduce operator
* \tparam DType data type
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief loads the latest check point
@@ -361,31 +299,19 @@ class Reducer {
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void *prepare_arg = nullptr);
#if DMLC_USE_CXX11
/*!
* \brief customized in-place all reduce operation, with lambda function as preprocessor
* \param sendrecvbuf pointer to the array of objects to be reduced
* \param count number of elements to be reduced
* \param prepare_fun lambda function executed to prepare the data, if necessary
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
std::function<void()> prepare_fun);
#endif // DMLC_USE_CXX11
private:
@@ -416,17 +342,11 @@ class SerializeReducer {
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void *prepare_arg = nullptr);
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
/*!
@@ -436,16 +356,10 @@ class SerializeReducer {
* this includes budget limit for intermediate and final result
* \param count number of elements to be reduced
* \param prepare_fun lambda function executed to prepare the data, if necessary
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
std::function<void()> prepare_fun,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
std::function<void()> prepare_fun);
#endif // DMLC_USE_CXX11
private: