|
|
|
|
@@ -1,12 +1,15 @@
|
|
|
|
|
/*!
|
|
|
|
|
* Copyright by Contributors
|
|
|
|
|
* \file rabit-inl.h
|
|
|
|
|
* \brief implementation of inline template function for rabit interface
|
|
|
|
|
*
|
|
|
|
|
* \author Tianqi Chen
|
|
|
|
|
*/
|
|
|
|
|
#ifndef RABIT_RABIT_INL_H
|
|
|
|
|
#define RABIT_RABIT_INL_H
|
|
|
|
|
#ifndef RABIT_RABIT_INL_H_
|
|
|
|
|
#define RABIT_RABIT_INL_H_
|
|
|
|
|
// use engine for implementation
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "./io.h"
|
|
|
|
|
#include "./utils.h"
|
|
|
|
|
#include "../rabit.h"
|
|
|
|
|
@@ -30,15 +33,15 @@ inline DataType GetType<int>(void) {
|
|
|
|
|
return kInt;
|
|
|
|
|
}
|
|
|
|
|
// Explicit specialization of GetType for `unsigned int`: returns the kUInt tag.
// NOTE(review): the signature line appears twice below (once plain, once with a
// trailing lint-suppression comment). This looks like the before/after pair of
// a diff view; only one of the two belongs in a compilable header.
template<>
|
|
|
|
|
inline DataType GetType<unsigned int>(void) {
|
|
|
|
|
inline DataType GetType<unsigned int>(void) { // NOLINT(*)
|
|
|
|
|
return kUInt;
|
|
|
|
|
}
|
|
|
|
|
// Explicit specialization of GetType for `long`: returns the kLong tag.
// NOTE(review): the signature line appears twice below (once plain, once with a
// trailing lint-suppression comment). This looks like the before/after pair of
// a diff view; only one of the two belongs in a compilable header.
template<>
|
|
|
|
|
inline DataType GetType<long>(void) {
|
|
|
|
|
inline DataType GetType<long>(void) { // NOLINT(*)
|
|
|
|
|
return kLong;
|
|
|
|
|
}
|
|
|
|
|
// Explicit specialization of GetType for `unsigned long`: returns the kULong tag.
// NOTE(review): the signature line appears twice below (once plain, once with a
// trailing lint-suppression comment). This looks like the before/after pair of
// a diff view; only one of the two belongs in a compilable header.
template<>
|
|
|
|
|
inline DataType GetType<unsigned long>(void) {
|
|
|
|
|
inline DataType GetType<unsigned long>(void) { // NOLINT(*)
|
|
|
|
|
return kULong;
|
|
|
|
|
}
|
|
|
|
|
template<>
|
|
|
|
|
@@ -50,54 +53,54 @@ inline DataType GetType<double>(void) {
|
|
|
|
|
return kDouble;
|
|
|
|
|
}
|
|
|
|
|
// Explicit specialization of GetType for `long long`: returns the kLongLong tag.
// NOTE(review): both the signature and the return statement appear twice below
// (old version first, then the version carrying the lint-suppression comment).
// This looks like the before/after pair of a diff view; only one pair belongs
// in a compilable header.
template<>
|
|
|
|
|
inline DataType GetType<long long>(void) {
|
|
|
|
|
return kLongLong;
|
|
|
|
|
inline DataType GetType<long long>(void) { // NOLINT(*)
|
|
|
|
|
return kLongLong;
|
|
|
|
|
}
|
|
|
|
|
// Explicit specialization of GetType for `unsigned long long`: returns the
// kULongLong tag.
// NOTE(review): both the signature and the return statement appear twice below
// (old version first, then the version carrying the lint-suppression comment).
// This looks like the before/after pair of a diff view; only one pair belongs
// in a compilable header.
template<>
|
|
|
|
|
inline DataType GetType<unsigned long long>(void) {
|
|
|
|
|
return kULongLong;
|
|
|
|
|
inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
|
|
|
|
|
return kULongLong;
|
|
|
|
|
}
|
|
|
|
|
} // namespace mpi
|
|
|
|
|
} // namespace engine
|
|
|
|
|
|
|
|
|
|
namespace op {
|
|
|
|
|
// Reduction operator "max".
//   kType  - engine-side tag for this operator (engine::mpi::kMax).
//   Reduce - overwrites dst with src when dst < src, so dst ends up holding
//            the larger of the two values (relies on DType's operator<).
// NOTE(review): the kType declaration and the Reduce header each appear twice
// (`const static` vs `static const`, and without/with the lint-suppression
// comment). This looks like the before/after pair of a diff view; only one of
// each belongs in a compilable header.
struct Max {
|
|
|
|
|
const static engine::mpi::OpType kType = engine::mpi::kMax;
|
|
|
|
|
static const engine::mpi::OpType kType = engine::mpi::kMax;
|
|
|
|
|
template<typename DType>
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) {
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
|
|
|
|
if (dst < src) dst = src;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// Reduction operator "min".
//   kType  - engine-side tag for this operator (engine::mpi::kMin).
//   Reduce - overwrites dst with src when dst > src, so dst ends up holding
//            the smaller of the two values (relies on DType's operator>).
// NOTE(review): the kType declaration and the Reduce header each appear twice
// (`const static` vs `static const`, and without/with the lint-suppression
// comment). This looks like the before/after pair of a diff view; only one of
// each belongs in a compilable header.
struct Min {
|
|
|
|
|
const static engine::mpi::OpType kType = engine::mpi::kMin;
|
|
|
|
|
static const engine::mpi::OpType kType = engine::mpi::kMin;
|
|
|
|
|
template<typename DType>
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) {
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
|
|
|
|
if (dst > src) dst = src;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// Reduction operator "sum".
//   kType  - engine-side tag for this operator (engine::mpi::kSum).
//   Reduce - accumulates src into dst with operator+=.
// NOTE(review): the kType declaration and the Reduce header each appear twice
// (`const static` vs `static const`, and without/with the lint-suppression
// comment). This looks like the before/after pair of a diff view; only one of
// each belongs in a compilable header.
struct Sum {
|
|
|
|
|
const static engine::mpi::OpType kType = engine::mpi::kSum;
|
|
|
|
|
static const engine::mpi::OpType kType = engine::mpi::kSum;
|
|
|
|
|
template<typename DType>
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) {
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
|
|
|
|
dst += src;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// Reduction operator "bitwise OR".
//   kType  - engine-side tag for this operator (engine::mpi::kBitwiseOR).
//   Reduce - ORs src into dst with operator|=, so DType must support |=.
// NOTE(review): the kType declaration and the Reduce header each appear twice
// (`const static` vs `static const`, and without/with the lint-suppression
// comment). This looks like the before/after pair of a diff view; only one of
// each belongs in a compilable header.
struct BitOR {
|
|
|
|
|
const static engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
|
|
|
|
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
|
|
|
|
template<typename DType>
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) {
|
|
|
|
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
|
|
|
|
dst |= src;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// Element-wise reduction callback: treats src_ and dst_ as arrays of DType and
// applies OP::Reduce(dst[i], src[i]) for every i in [0, len), leaving the
// combined result in dst_.
//   src_  - read-only input buffer, cast to const DType*
//   dst_  - in/out buffer, cast to DType*
//   len   - number of DType elements to process
//   dtype - not used by the visible body
// NOTE(review): the `dst` declaration appears twice (without/with the
// lint-suppression comment). This looks like the before/after pair of a diff
// view; only one of them belongs in a compilable header.
template<typename OP, typename DType>
|
|
|
|
|
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
|
|
|
|
const DType *src = (const DType*)src_;
|
|
|
|
|
DType *dst = (DType*)dst_;
|
|
|
|
|
DType *dst = (DType*)dst_; // NOLINT(*)
|
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
|
|
|
OP::Reduce(dst[i], src[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace op
|
|
|
|
|
} // namespace op
|
|
|
|
|
|
|
|
|
|
// initialize the rabit engine
|
|
|
|
|
inline void Init(int argc, char *argv[]) {
|
|
|
|
|
@@ -152,23 +155,23 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
|
|
|
|
|
// perform inplace Allreduce
// Thin forwarding wrapper around engine::Allreduce_. It passes on:
//   - sendrecvbuf, the element size sizeof(DType), and count;
//   - op::Reducer<OP, DType> as the element-wise reduction callback;
//   - the type tag engine::mpi::GetType<DType>() and the operator tag OP::kType;
//   - prepare_fun / prepare_arg unchanged.
// presumably the engine calls prepare_fun(prepare_arg) before reducing --
// confirm against engine::Allreduce_.
// NOTE(review): the prepare_fun parameter line and the first line of the
// engine::Allreduce_ call each appear twice (the call differs only in the
// space after the comma in `Reducer<OP, DType>`). This looks like the
// before/after pair of a diff view; only one of each belongs in a compilable
// header.
template<typename OP, typename DType>
|
|
|
|
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|
|
|
|
void (*prepare_fun)(void *arg),
|
|
|
|
|
void (*prepare_fun)(void *arg),
|
|
|
|
|
void *prepare_arg) {
|
|
|
|
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
|
|
|
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
|
|
|
|
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// C++11 support for lambda prepare function
|
|
|
|
|
#if __cplusplus >= 201103L
|
|
|
|
|
#if DMLC_USE_CXX11
|
|
|
|
|
// Trampoline with a plain `void (*)(void*)` signature: interprets `fun` as a
// pointer to std::function<void()> and invokes the pointed-to callable.
// `fun` must actually point to a live std::function<void()>; nothing here
// checks the pointer or the type.
inline void InvokeLambda_(void *fun) {
|
|
|
|
|
(*static_cast<std::function<void()>*>(fun))();
|
|
|
|
|
}
|
|
|
|
|
// Allreduce overload that accepts the prepare step as a std::function<void()>.
// Forwards to engine::Allreduce_ with the same buffer/size/reducer/type/op
// arguments as the function-pointer overload, but passes InvokeLambda_ as the
// callback and the address of the by-value parameter `prepare_fun` as its
// argument. That address is only valid while this function is executing.
// NOTE(review): the first line of the engine::Allreduce_ call appears twice
// (differing only in the space after the comma in `Reducer<OP, DType>`). This
// looks like the before/after pair of a diff view; only one of them belongs in
// a compilable header.
template<typename OP, typename DType>
|
|
|
|
|
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
|
|
|
|
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
|
|
|
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
|
|
|
|
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
|
|
|
|
|
}
|
|
|
|
|
#endif // C++11
|
|
|
|
|
#endif // C++11
|
|
|
|
|
|
|
|
|
|
// print message to the tracker
|
|
|
|
|
inline void TrackerPrint(const std::string &msg) {
|
|
|
|
|
@@ -223,15 +226,16 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// function to perform reduction for Reducer
// Reinterprets src_ and dst_ directly as arrays of DType and calls
// freduce(pdst[i], psrc[i]) for every i in [0, len_), leaving the combined
// result in dst_.
//   src_  - read-only input buffer
//   dst_  - in/out buffer
//   len_  - number of DType elements to process
//   dtype - not used by the visible body
// assumes both buffers are suitably aligned for DType, since they are accessed
// through reinterpret_cast'ed pointers -- TODO confirm against the caller.
// NOTE(review): the template header and the function signature each appear
// twice (the old one-line signature, then the re-wrapped two-line one with a
// lint-suppression comment on the template line). This looks like the
// before/after pair of a diff view; only one of each belongs in a compilable
// header.
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
|
|
|
|
inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
|
|
|
|
inline void ReducerAlign_(const void *src_, void *dst_,
|
|
|
|
|
int len_, const MPI::Datatype &dtype) {
|
|
|
|
|
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
|
|
|
|
DType *pdst = reinterpret_cast<DType*>(dst_);
|
|
|
|
|
for (int i = 0; i < len_; ++i) {
|
|
|
|
|
freduce(pdst[i], psrc[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
|
|
|
|
inline Reducer<DType, freduce>::Reducer(void) {
|
|
|
|
|
// it is safe to directly use handle for aligned data types
|
|
|
|
|
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
|
|
|
|
@@ -240,7 +244,7 @@ inline Reducer<DType, freduce>::Reducer(void) {
|
|
|
|
|
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
|
|
|
|
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
|
|
|
|
void (*prepare_fun)(void *arg),
|
|
|
|
|
void *prepare_arg) {
|
|
|
|
|
@@ -248,13 +252,14 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
|
|
|
|
}
|
|
|
|
|
// function to perform reduction for SerializeReducer
|
|
|
|
|
template<typename DType>
|
|
|
|
|
inline void SerializeReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
|
|
|
|
inline void SerializeReducerFunc_(const void *src_, void *dst_,
|
|
|
|
|
int len_, const MPI::Datatype &dtype) {
|
|
|
|
|
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
|
|
|
|
// temp space
|
|
|
|
|
DType tsrc, tdst;
|
|
|
|
|
for (int i = 0; i < len_; ++i) {
|
|
|
|
|
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
|
|
|
|
|
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
|
|
|
|
|
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); // NOLINT(*)
|
|
|
|
|
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); // NOLINT(*)
|
|
|
|
|
tsrc.Load(fsrc);
|
|
|
|
|
tdst.Load(fdst);
|
|
|
|
|
// govern const check
|
|
|
|
|
@@ -296,8 +301,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|
|
|
|
// setup closure
|
|
|
|
|
SerializeReduceClosure<DType> c;
|
|
|
|
|
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
|
|
|
|
|
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
|
|
|
|
// invoke here
|
|
|
|
|
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
|
|
|
|
// invoke here
|
|
|
|
|
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
|
|
|
|
SerializeReduceClosure<DType>::Invoke, &c);
|
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
|
|
|
@@ -306,8 +311,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if __cplusplus >= 201103L
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
|
|
|
|
#if DMLC_USE_CXX11
|
|
|
|
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
|
|
|
|
|
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
|
|
|
|
std::function<void()> prepare_fun) {
|
|
|
|
|
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
|
|
|
|
@@ -320,4 +325,4 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
} // namespace rabit
|
|
|
|
|
#endif
|
|
|
|
|
#endif // RABIT_RABIT_INL_H_
|
|
|
|
|
|