// implementations in ctypes #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #include #include #include #include "./rabit_wrapper.h" namespace rabit { namespace wrapper { // helper use to avoid BitOR operator template struct FHelper { inline static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), void *prepare_arg) { rabit::Allreduce(senrecvbuf_, count, prepare_fun, prepare_arg); } }; template struct FHelper { inline static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), void *prepare_arg) { utils::Error("DataType does not support bitwise or operation"); } }; template inline void Allreduce_(void *sendrecvbuf_, size_t count, engine::mpi::DataType enum_dtype, void (*prepare_fun)(void *arg), void *prepare_arg) { using namespace engine::mpi; switch (enum_dtype) { case kChar: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kUChar: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kInt: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kUInt: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kLong: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kULong: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kFloat: FHelper::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kDouble: FHelper::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; default: utils::Error("unknown data_type"); } } inline void Allreduce(void *sendrecvbuf, size_t count, engine::mpi::DataType enum_dtype, engine::mpi::OpType enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { using namespace engine::mpi; switch (enum_op) { case kMax: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kMin: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kSum: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kBitwiseOR: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; default: utils::Error("unknown enum_op"); } } } // namespace wrapper } // namespace rabit extern "C" { void RabitInit(int argc, char *argv[]) { rabit::Init(argc, argv); } void RabitFinalize(void) { rabit::Finalize(); } int RabitGetRank(void) { return rabit::GetRank(); } int RabitGetWorldSize(void) { return rabit::GetWorldSize(); } void RabitTrackerPrint(const char *msg) { std::string m(msg); rabit::TrackerPrint(m); } void RabitGetProcessorName(char *out_name, rbt_ulong *out_len, rbt_ulong max_len) { std::string s = rabit::GetProcessorName(); if (s.length() > max_len) { s.resize(max_len - 1); } strcpy(out_name, s.c_str()); *out_len = static_cast(s.length()); } void RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root) { rabit::Broadcast(sendrecv_data, size, root); } void RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, int enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { rabit::wrapper::Allreduce (sendrecvbuf, count, static_cast(enum_dtype), static_cast(enum_op), prepare_fun, prepare_arg); } }