enable support for lambda preprocessing function, and c++11
This commit is contained in:
parent
58331067f8
commit
1754fdbf4e
@ -64,11 +64,18 @@ class AllreduceBase : public IEngine {
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed to the lazy preprocessing function
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
|
||||
"Allreduce failed");
|
||||
}
|
||||
|
||||
@ -57,17 +57,24 @@ void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed to the lazy preprocessing function
|
||||
*/
|
||||
void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
||||
// now we are free to remove the last result, if any
|
||||
if (resbuf.LastSeqNo() != -1 &&
|
||||
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||
resbuf.DropLast();
|
||||
}
|
||||
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
void *temp = resbuf.AllocTemp(type_nbytes, count);
|
||||
while (true) {
|
||||
if (recovered) {
|
||||
|
||||
@ -35,11 +35,17 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
*/
|
||||
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed to the lazy preprocessing function
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
|
||||
@ -43,8 +43,10 @@ void Allreduce_(void *sendrecvbuf,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red);
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
26
src/engine.h
26
src/engine.h
@ -20,6 +20,12 @@ namespace engine {
|
||||
class IEngine {
|
||||
public:
|
||||
/*!
|
||||
* \brief Preprocessing function, that is called before AllReduce,
|
||||
* used to prepare the data used by AllReduce
|
||||
* \param arg additional possible argument used to invoke the preprocessor
|
||||
*/
|
||||
typedef void (PreprocFunction) (void *arg);
|
||||
/*!
|
||||
* \brief reduce function, the same form of MPI reduce function is used,
|
||||
* to be compatible with MPI interface
|
||||
* In all the functions, the memory is guaranteed to be aligned to 64 bits
|
||||
@ -34,17 +40,23 @@ class IEngine {
|
||||
void *dst, int count,
|
||||
const MPI::Datatype &dtype);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed to the lazy preprocessing function
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) = 0;
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL) = 0;
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
@ -145,13 +157,19 @@ enum DataType {
|
||||
* \param reducer reduce function
|
||||
* \param dtype the data type
|
||||
* \param op the reduce operator type
|
||||
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed to the lazy preprocessing function
|
||||
*/
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op);
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ENGINE_H
|
||||
|
||||
@ -23,7 +23,9 @@ class MPIEngine : public IEngine {
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
@ -110,7 +112,10 @@ void Allreduce_(void *sendrecvbuf,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op) {
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
|
||||
}
|
||||
} // namespace engine
|
||||
|
||||
@ -119,10 +119,25 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
|
||||
// perform inplace Allreduce
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count) {
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType);
|
||||
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
|
||||
}
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if __cplusplus >= 201103L
|
||||
inline void InvokeLambda_(void *fun) {
|
||||
(*static_cast<std::function<void()>*>(fun))();
|
||||
}
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
|
||||
40
src/rabit.h
40
src/rabit.h
@ -10,6 +10,11 @@
|
||||
*/
|
||||
#include <string>
|
||||
#include <vector>
|
||||
// optional support for lambda functions in C++11, if available
|
||||
#if __cplusplus >= 201103L
|
||||
#include <functional>
|
||||
#endif // C++11
|
||||
// rabit headers
|
||||
#include "./io.h"
|
||||
#include "./engine.h"
|
||||
|
||||
@ -78,11 +83,44 @@ inline void Broadcast(std::string *sendrecv_data, int root);
|
||||
* ...
|
||||
* \param sendrecvbuf buffer for both sending and recving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed to the lazy preprocessing function
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType type of data
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count);
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if __cplusplus >= 201103L
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* with a prepare function specified by lambda function
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
|
||||
* for (int i = 0; i < 10; ++i) {
|
||||
* data[i] = i;
|
||||
* }
|
||||
* });
|
||||
* ...
|
||||
* \param sendrecvbuf buffer for both sending and recving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy lambda preprocessing function; prepare_fun() will be invoked
|
||||
* by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType type of data
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun);
|
||||
#endif // C++11
|
||||
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
|
||||
@ -2,7 +2,7 @@ export CC = gcc
|
||||
export CXX = g++
|
||||
export MPICXX = mpicxx
|
||||
export LDFLAGS= -pthread -lm -lrt
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
|
||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src -std=c++11
|
||||
|
||||
# specify tensor path
|
||||
BIN = speed_test test_model_recover test_local_recover
|
||||
|
||||
@ -50,14 +50,17 @@ class Model : public rabit::utils::ISerializable {
|
||||
inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
|
||||
int rank = rabit::GetRank();
|
||||
int nproc = rabit::GetWorldSize();
|
||||
const int z = iter + 111;
|
||||
|
||||
const int z = iter + 111;
|
||||
std::vector<float> ndata(model->data.size());
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
||||
}
|
||||
|
||||
test::CallBegin("Allreduce::Max", ntrial, iter);
|
||||
rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||
rabit::Allreduce<op::Max>(&ndata[0], ndata.size(),
|
||||
[&]() {
|
||||
// use lambda expression to prepare the data
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
||||
}
|
||||
});
|
||||
test::CallEnd("Allreduce::Max", ntrial, iter);
|
||||
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user