enable support for lambda preprocessing function, and c++11

This commit is contained in:
tqchen 2014-12-19 02:00:43 -08:00
parent 58331067f8
commit 1754fdbf4e
10 changed files with 124 additions and 23 deletions

View File

@ -64,11 +64,18 @@ class AllreduceBase : public IEngine {
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
*/
// This implementation performs no recovery check: prepare_fun is invoked
// whenever it is non-NULL, and only then is the reduction attempted.
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) {
// run the caller-supplied hook first, if one was given
if (prepare_fun != NULL) prepare_fun(prepare_arg);
// any TryAllreduce result other than kSuccess is reported through utils::Assert
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
"Allreduce failed");
}

View File

@ -57,17 +57,24 @@ void AllreduceRobust::SetParam(const char *name, const char *val) {
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
*/
void AllreduceRobust::Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
void *temp = resbuf.AllocTemp(type_nbytes, count);
while (true) {
if (recovered) {

View File

@ -35,11 +35,17 @@ class AllreduceRobust : public AllreduceBase {
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
*/
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data

View File

@ -43,8 +43,10 @@ void Allreduce_(void *sendrecvbuf,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red);
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@ -20,6 +20,12 @@ namespace engine {
class IEngine {
public:
/*!
* \brief Preprocessing function, that is called before AllReduce,
* used to prepare the data used by AllReduce
* \param arg additional possible argument used to invoke the preprocessor
*/
typedef void (PreprocFunction) (void *arg);
/*!
* \brief reduce function, the same form of MPI reduce function is used,
* to be compatible with MPI interface
* In all the functions, the memory is ensured to aligned to 64-bit
@ -34,17 +40,23 @@ class IEngine {
void *dst, int count,
const MPI::Datatype &dtype);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) = 0;
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) = 0;
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
@ -145,13 +157,19 @@ enum DataType {
* \param reducer reduce function
* \param dtype the data type
* \param op the reduce operator type
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
*/
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op);
mpi::OpType op,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_H

View File

@ -23,7 +23,9 @@ class MPIEngine : public IEngine {
// This override does not reduce anything and does not invoke prepare_fun:
// its body only reports, through utils::Error, that Allreduce_ must be used instead.
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
@ -110,7 +112,10 @@ void Allreduce_(void *sendrecvbuf,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op) {
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
}
} // namespace engine

View File

@ -119,10 +119,25 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
// perform inplace Allreduce
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count) {
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
engine::mpi::GetType<DType>(), OP::kType);
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
}
// C++11 support for lambda prepare function
#if __cplusplus >= 201103L
/*!
 * \brief trampoline with the plain void(void*) signature that runs a type-erased callable
 * \param fun address of a std::function<void()> object; it is cast back and invoked once
 */
inline void InvokeLambda_(void *fun) {
  std::function<void()> *callback = static_cast<std::function<void()>*>(fun);
  (*callback)();
}
/*!
 * \brief in-place Allreduce whose lazy preprocessing step is a C++11 callable
 * \param sendrecvbuf buffer for both sending and receiving data
 * \param count number of elements to be reduced
 * \param prepare_fun callable forwarded to the engine through InvokeLambda_;
 *   an empty std::function is treated as "no preprocessing" and forwarded as
 *   NULL, so the engine's (prepare_fun != NULL) guard skips it
 * \tparam OP see namespace op, reduce operator
 * \tparam DType type of data
 */
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
  void (*invoker)(void *arg) = NULL;
  void *closure = NULL;
  // invoking an empty std::function throws std::bad_function_call, so the
  // trampoline is only installed when there is actually something to run
  if (prepare_fun) {
    invoker = InvokeLambda_;
    // prepare_fun is a by-value parameter, so its address stays valid
    // for the whole duration of the engine call below
    closure = &prepare_fun;
  }
  engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
                     engine::mpi::GetType<DType>(), OP::kType, invoker, closure);
}
#endif // C++11
// load latest check point
inline int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) {

View File

@ -10,6 +10,11 @@
*/
#include <string>
#include <vector>
// optional support for lambda functions in C++11, if available
#if __cplusplus >= 201103L
#include <functional>
#endif // C++11
// rabit headers
#include "./io.h"
#include "./engine.h"
@ -78,11 +83,44 @@ inline void Broadcast(std::string *sendrecv_data, int root);
* ...
* \param sendrecvbuf buffer for both sending and recving data
* \param count number of elements to be reduced
* \param prepare_fun Lazy preprocessing function; if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
* \tparam OP see namespace op, reduce operator
* \tparam DType type of data
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count);
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
// C++11 support for lambda prepare function
#if __cplusplus >= 201103L
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* with a prepare function specified by lambda function
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
* for (int i = 0; i < 10; ++i) {
* data[i] = i;
* }
* });
* ...
* \param sendrecvbuf buffer for both sending and recving data
* \param count number of elements to be reduced
* \param prepare_fun Lazy lambda preprocessing function; prepare_fun() will be invoked
* by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \tparam OP see namespace op, reduce operator
* \tparam DType type of data
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state

View File

@ -2,7 +2,7 @@ export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -lrt
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src -std=c++11
# specify tensor path
BIN = speed_test test_model_recover test_local_recover

View File

@ -50,14 +50,17 @@ class Model : public rabit::utils::ISerializable {
inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
const int z = iter + 111;
const int z = iter + 111;
std::vector<float> ndata(model->data.size());
for (size_t i = 0; i < ndata.size(); ++i) {
ndata[i] = (i * (rank+1)) % z + local->data[i];
}
test::CallBegin("Allreduce::Max", ntrial, iter);
rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
rabit::Allreduce<op::Max>(&ndata[0], ndata.size(),
[&]() {
// use lambda expression to prepare the data
for (size_t i = 0; i < ndata.size(); ++i) {
ndata[i] = (i * (rank+1)) % z + local->data[i];
}
});
test::CallEnd("Allreduce::Max", ntrial, iter);
for (size_t i = 0; i < ndata.size(); ++i) {