diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 4ef4a044e..bc7cc26c9 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -64,11 +64,18 @@ class AllreduceBase : public IEngine { * \param type_nbytes the unit number of bytes the type have * \param count number of elements to be reduced * \param reducer reduce function + * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function */ virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer) { + ReduceFunction reducer, + PreprocFunction prepare_fun = NULL, + void *prepare_arg = NULL) { + if (prepare_fun != NULL) prepare_fun(prepare_arg); utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess, "Allreduce failed"); } diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 828f57e60..7d339cf84 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -57,17 +57,24 @@ void AllreduceRobust::SetParam(const char *name, const char *val) { * \param type_nbytes the unit number of bytes the type have * \param count number of elements to be reduced * \param reducer reduce function + * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function */ void AllreduceRobust::Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer) { + ReduceFunction reducer, + PreprocFunction prepare_fun, + void *prepare_arg) { bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); // now we are free to remove the last result, if any if (resbuf.LastSeqNo() != -1 && (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { resbuf.DropLast(); } + if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg); void *temp = resbuf.AllocTemp(type_nbytes, count); while (true) { if (recovered) { diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 92c682b12..7888a66f1 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -35,11 +35,17 @@ class AllreduceRobust : public AllreduceBase { * \param type_nbytes the unit number of bytes the type have * \param count number of elements to be reduced * \param reducer reduce function - */ + * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function + */ virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, - size_t count, - ReduceFunction reducer); + size_t count, + ReduceFunction reducer, + PreprocFunction prepare_fun = NULL, + void *prepare_arg = NULL); /*! * \brief broadcast data from root to all nodes * \param sendrecvbuf_ buffer for both sending and recving data diff --git a/src/engine.cc b/src/engine.cc index 0512ac503..cc6a48745 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -43,8 +43,10 @@ void Allreduce_(void *sendrecvbuf, size_t count, IEngine::ReduceFunction red, mpi::DataType dtype, - mpi::OpType op) { - GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red); + mpi::OpType op, + IEngine::PreprocFunction prepare_fun, + void *prepare_arg) { + GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg); } } // namespace engine } // namespace rabit diff --git a/src/engine.h b/src/engine.h index e393e94db..977b0d6ff 100644 --- a/src/engine.h +++ b/src/engine.h @@ -20,6 +20,12 @@ namespace engine { class IEngine { public: /*! + * \brief Preprocessing function, that is called before AllReduce, + * used to prepare the data used by AllReduce + * \param arg additional possible argument used to invoke the preprocessor + */ + typedef void (PreprocFunction) (void *arg); + /*! * \brief reduce function, the same form of MPI reduce function is used, * to be compatible with MPI interface * In all the functions, the memory is ensured to aligned to 64-bit @@ -34,17 +40,23 @@ class IEngine { void *dst, int count, const MPI::Datatype &dtype); /*! - * \brief perform in-place allreduce, on sendrecvbuf + * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe * \param sendrecvbuf_ buffer for both sending and recving data * \param type_nbytes the unit number of bytes the type have * \param count number of elements to be reduced * \param reducer reduce function + * \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function */ virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer) = 0; + ReduceFunction reducer, + PreprocFunction prepare_fun = NULL, + void *prepare_arg = NULL) = 0; /*! * \brief broadcast data from root to all nodes * \param sendrecvbuf_ buffer for both sending and recving data @@ -145,13 +157,19 @@ enum DataType { * \param reducer reduce function * \param dtype the data type * \param op the reduce operator type + * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function * */ void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, - IEngine::ReduceFunction red, + IEngine::ReduceFunction red, mpi::DataType dtype, - mpi::OpType op); + mpi::OpType op, + IEngine::PreprocFunction prepare_fun = NULL, + void *prepare_arg = NULL); } // namespace engine } // namespace rabit #endif // RABIT_ENGINE_H diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index f32dba854..9e5972e1a 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -23,7 +23,9 @@ class MPIEngine : public IEngine { virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer) { + ReduceFunction reducer, + PreprocFunction prepare_fun, + void *prepare_arg) { utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead"); } virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { @@ -110,7 +112,10 @@ void Allreduce_(void *sendrecvbuf, size_t count, IEngine::ReduceFunction red, mpi::DataType dtype, - mpi::OpType op) { + mpi::OpType op, + IEngine::PreprocFunction prepare_fun, + void *prepare_arg) { + if (prepare_fun != NULL) prepare_fun(prepare_arg); MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op)); } } // namespace engine diff --git a/src/rabit-inl.h b/src/rabit-inl.h index 95a2eb8fd..8d379f920 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -119,10 +119,25 @@ inline void Broadcast(std::string *sendrecv_data, int root) { // perform inplace Allreduce template -inline void Allreduce(DType *sendrecvbuf, size_t count) { +inline void Allreduce(DType *sendrecvbuf, size_t count, + void (*prepare_fun)(void *arg), + void *prepare_arg) { engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, - engine::mpi::GetType(), OP::kType); + engine::mpi::GetType(), OP::kType, prepare_fun, prepare_arg); } + +// C++11 support for lambda prepare function +#if __cplusplus >= 201103L +inline void InvokeLambda_(void *fun) { + (*static_cast*>(fun))(); +} +template +inline void Allreduce(DType *sendrecvbuf, size_t count, std::function prepare_fun) { + engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, + engine::mpi::GetType(), OP::kType, InvokeLambda_, &prepare_fun); +} +#endif // C++11 + // load latest check point inline int LoadCheckPoint(utils::ISerializable *global_model, utils::ISerializable *local_model) { diff --git a/src/rabit.h b/src/rabit.h index c7cde6b4b..ac17faec6 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -10,6 +10,11 @@ */ #include #include +// optionally support of lambda function in C++11, if available +#if __cplusplus >= 201103L +#include +#endif // C++11 +// rabit headers #include "./io.h" #include "./engine.h" @@ -78,11 +83,44 @@ inline void Broadcast(std::string *sendrecv_data, int root); * ... * \param sendrecvbuf buffer for both sending and recving data * \param count number of elements to be reduced + * \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function * \tparam OP see namespace op, reduce operator * \tparam DType type of data */ template -inline void Allreduce(DType *sendrecvbuf, size_t count); +inline void Allreduce(DType *sendrecvbuf, size_t count, + void (*prepare_fun)(void *arg) = NULL, + void *prepare_arg = NULL); + +// C++11 support for lambda prepare function +#if __cplusplus >= 201103L +/*! + * \brief perform in-place allreduce, on sendrecvbuf + * with a prepare function specified by lambda function + * Example Usage: the following code gives sum of the result + * vector data(10); + * ... + * Allreduce(&data[0], data.size(), [&]() { + * for (int i = 0; i < 10; ++i) { + * data[i] = i; + * } + * }); + * ... + * \param sendrecvbuf buffer for both sending and recving data + * \param count number of elements to be reduced + * \param prepare_func Lazy lambda preprocessing function, prepare_fun() will be invoked + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \tparam OP see namespace op, reduce operator + * \tparam DType type of data + */ +template +inline void Allreduce(DType *sendrecvbuf, size_t count, std::function prepare_fun); +#endif // C++11 + /*! * \brief load latest check point * \param global_model pointer to the globally shared model/state diff --git a/test/Makefile b/test/Makefile index 2f3b81251..18d876b2e 100644 --- a/test/Makefile +++ b/test/Makefile @@ -2,7 +2,7 @@ export CC = gcc export CXX = g++ export MPICXX = mpicxx export LDFLAGS= -pthread -lm -lrt -export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src +export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src -std=c++11 # specify tensor path BIN = speed_test test_model_recover test_local_recover diff --git a/test/test_local_recover.cpp b/test/test_local_recover.cpp index 2d2c8234c..d98c6ae48 100644 --- a/test/test_local_recover.cpp +++ b/test/test_local_recover.cpp @@ -50,14 +50,17 @@ class Model : public rabit::utils::ISerializable { inline void TestMax(Model *model, Model *local, int ntrial, int iter) { int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); - const int z = iter + 111; - + const int z = iter + 111; std::vector ndata(model->data.size()); - for (size_t i = 0; i < ndata.size(); ++i) { - ndata[i] = (i * (rank+1)) % z + local->data[i]; - } + test::CallBegin("Allreduce::Max", ntrial, iter); - rabit::Allreduce(&ndata[0], ndata.size()); + rabit::Allreduce(&ndata[0], ndata.size(), + [&]() { + // use lambda expression to prepare the data + for (size_t i = 0; i < ndata.size(); ++i) { + ndata[i] = (i * (rank+1)) % z + local->data[i]; + } + }); test::CallEnd("Allreduce::Max", ntrial, iter); for (size_t i = 0; i < ndata.size(); ++i) {