diff --git a/.gitignore b/.gitignore index 6a2c32d90..31eed0c3b 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,5 @@ *.exe *.txt *tmp* -doc \ No newline at end of file +doc +*.rabit diff --git a/Makefile b/Makefile index 27c2b1915..daa08aa79 100644 --- a/Makefile +++ b/Makefile @@ -1,20 +1,23 @@ export CC = gcc export CXX = g++ export MPICXX = mpicxx -export LDFLAGS= +export LDFLAGS= -Llib export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -Iinclude # build path BPATH=. # objectives that makes up rabit library MPIOBJ= $(BPATH)/engine_mpi.o -OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o +OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\ + $(BPATH)/rabit_wrapper.o +SLIB= wrapper/librabit_wrapper.so ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a HEADERS=src/*.h include/*.h include/rabit/*.h -.PHONY: clean all install mpi +.PHONY: clean all install mpi python -all: lib/librabit.a lib/librabit_mock.a +all: lib/librabit.a lib/librabit_mock.a $(SLIB) mpi: lib/librabit_mpi.a +python: wrapper/librabit_wrapper.so $(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS) $(BPATH)/engine.o: src/engine.cc $(HEADERS) @@ -27,6 +30,9 @@ lib/librabit.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/e lib/librabit_mock.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine_mock.o lib/librabit_empty.a: $(BPATH)/engine_empty.o lib/librabit_mpi.a: $(MPIOBJ) +# wrapper code +$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc +wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a $(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) @@ -37,5 +43,8 @@ $(MPIOBJ) : $(ALIB): ar cr $@ $+ +$(SLIB) : + $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) + clean: $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ diff --git a/wrapper/rabit_wrapper.cc b/wrapper/rabit_wrapper.cc new file mode 100644 index 000000000..60c53f93a --- /dev/null +++ b/wrapper/rabit_wrapper.cc @@ -0,0 +1,164 @@ +// implementations in ctypes +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE + +#include +#include +#include +#include "./rabit_wrapper.h" +namespace rabit { +namespace wrapper { +// helper use to avoid BitOR operator +template +struct FHelper { + inline static void + Allreduce(DType *senrecvbuf_, + size_t count, + void (*prepare_fun)(void *arg), + void *prepare_arg) { + rabit::Allreduce(senrecvbuf_, count, + prepare_fun, prepare_arg); + } +}; +template +struct FHelper { + inline static void + Allreduce(DType *senrecvbuf_, + size_t count, + void (*prepare_fun)(void *arg), + void *prepare_arg) { + utils::Error("DataType does not support bitwise or operation"); + } +}; +template +inline void Allreduce_(void *sendrecvbuf_, + size_t count, + engine::mpi::DataType enum_dtype, + void (*prepare_fun)(void *arg), + void *prepare_arg) { + using namespace engine::mpi; + switch (enum_dtype) { + case kChar: + rabit::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kUChar: + rabit::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kInt: + rabit::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kUInt: + rabit::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kLong: + rabit::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kULong: + rabit::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kFloat: + FHelper::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + case kDouble: + FHelper::Allreduce + (static_cast(sendrecvbuf_), + count, prepare_fun, prepare_arg); + return; + default: utils::Error("unknown data_type"); + } +} +inline void Allreduce(void *sendrecvbuf, + size_t count, + engine::mpi::DataType enum_dtype, + engine::mpi::OpType enum_op, + void (*prepare_fun)(void *arg), + void *prepare_arg) { + using namespace engine::mpi; + switch (enum_op) { + case kMax: + Allreduce_ + (sendrecvbuf, + count, enum_dtype, + prepare_fun, prepare_arg); + return; + case kMin: + Allreduce_ + (sendrecvbuf, + count, enum_dtype, + prepare_fun, prepare_arg); + return; + case kSum: + Allreduce_ + (sendrecvbuf, + count, enum_dtype, + prepare_fun, prepare_arg); + return; + case kBitwiseOR: + Allreduce_ + (sendrecvbuf, + count, enum_dtype, + prepare_fun, prepare_arg); + return; + default: utils::Error("unknown enum_op"); + } +} +} // namespace wrapper +} // namespace rabit +extern "C" { + void RabitInit(int argc, char *argv[]) { + rabit::Init(argc, argv); + } + void RabitFinalize(void) { + rabit::Finalize(); + } + int RabitGetRank(void) { + return rabit::GetRank(); + } + int RabitGetWorldSize(void) { + return rabit::GetWorldSize(); + } + void RabitTrackerPrint(const char *msg) { + std::string m(msg); + rabit::TrackerPrint(m); + } + void RabitGetProcessorName(char *out_name, + rbt_ulong *out_len, + rbt_ulong max_len) { + std::string s = rabit::GetProcessorName(); + if (s.length() > max_len) { + s.resize(max_len - 1); + } + strcpy(out_name, s.c_str()); + *out_len = static_cast(s.length()); + } + void RabitBroadcast(void *sendrecv_data, + rbt_ulong size, int root) { + rabit::Broadcast(sendrecv_data, size, root); + } + void RabitAllreduce(void *sendrecvbuf, + size_t count, + int enum_dtype, + int enum_op, + void (*prepare_fun)(void *arg), + void *prepare_arg) { + rabit::wrapper::Allreduce + (sendrecvbuf, count, + static_cast(enum_dtype), + static_cast(enum_op), + prepare_fun, prepare_arg); + } +} diff --git a/wrapper/rabit_wrapper.h b/wrapper/rabit_wrapper.h new file mode 100644 index 000000000..823f199b9 --- /dev/null +++ b/wrapper/rabit_wrapper.h @@ -0,0 +1,87 @@ +#ifndef RABIT_WRAPPER_H_ +#define RABIT_WRAPPER_H_ +/*! + * \file rabit_wrapper.h + * \author Tianqi Chen + * \brief a C style wrapper of rabit + * can be used to create wrapper of other languages + */ +#ifdef _MSC_VER +#define RABIT_DLL __declspec(dllexport) +#else +#define RABIT_DLL +#endif +// manually define unsign long +typedef unsigned long rbt_ulong; + +#ifdef __cplusplus +extern "C" { +#endif +/*! + * \brief intialize the rabit module, call this once before using anything + * \param argc number of arguments in argv + * \param argv the array of input arguments + */ + RABIT_DLL void RabitInit(int argc, char *argv[]); + /*! + * \brief finalize the rabit engine, call this function after you finished all jobs + */ + RABIT_DLL void RabitFinalize(void); + /*! \brief get rank of current process */ + RABIT_DLL int RabitGetRank(void); + /*! \brief get total number of process */ + RABIT_DLL int RabitGetWorldSize(void); + /*! + * \brief print the msg to the tracker, + * this function can be used to communicate the information of the progress to + * the user who monitors the tracker + * \param msg the message to be printed + */ + RABIT_DLL void RabitTrackerPrint(const char *msg); + /*! + * \brief get name of processor + * \param out_name hold output string + * \param out_len hold length of output string + * \param max_len maximum buffer length of input + */ + RABIT_DLL void RabitGetProcessorName(char *out_name, + rbt_ulong *out_len, + rbt_ulong max_len); + /*! + * \brief broadcast an memory region to all others from root + * + * Example: int a = 1; Broadcast(&a, sizeof(a), root); + * \param sendrecv_data the pointer to send or recive buffer, + * \param size the size of the data + * \param root the root of process + */ + RABIT_DLL void RabitBroadcast(void *sendrecv_data, + rbt_ulong size, int root); + /*! + * \brief perform in-place allreduce, on sendrecvbuf + * this function is NOT thread-safe + * + * Example Usage: the following code gives sum of the result + * vector data(10); + * ... + * Allreduce(&data[0], data.size()); + * ... + * \param sendrecvbuf buffer for both sending and recving data + * \param count number of elements to be reduced + * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include + * \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit + * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) + * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. + * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called + * \param prepare_arg argument used to passed into the lazy preprocessing function + */ + RABIT_DLL void RabitAllreduce(void *sendrecvbuf, + size_t count, + int enum_dtype, + int enum_op, + void (*prepare_fun)(void *arg), + void *prepare_arg); +#ifdef __cplusplus +} // C +#endif +#endif // XGBOOST_WRAPPER_H_