// Copyright by Contributors // implementations in ctypes #include #include #include #include "rabit/rabit.h" #include "rabit/c_api.h" namespace rabit { namespace c_api { // helper use to avoid BitOR operator template struct FHelper { static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), void *prepare_arg) { rabit::Allreduce(senrecvbuf_, count, prepare_fun, prepare_arg); } }; template struct FHelper { static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), void *prepare_arg) { utils::Error("DataType does not support bitwise or operation"); } }; template void Allreduce_(void *sendrecvbuf_, size_t count, engine::mpi::DataType enum_dtype, void (*prepare_fun)(void *arg), void *prepare_arg) { using namespace engine::mpi; switch (enum_dtype) { case kChar: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kUChar: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kInt: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kUInt: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kLong: rabit::Allreduce (static_cast(sendrecvbuf_), // NOLINT(*) count, prepare_fun, prepare_arg); return; case kULong: rabit::Allreduce (static_cast(sendrecvbuf_), // NOLINT(*) count, prepare_fun, prepare_arg); return; case kFloat: FHelper::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kDouble: FHelper::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; default: utils::Error("unknown data_type"); } } void Allreduce(void *sendrecvbuf, size_t count, engine::mpi::DataType enum_dtype, engine::mpi::OpType enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { using namespace engine::mpi; switch (enum_op) { case kMax: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kMin: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kSum: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kBitwiseOR: Allreduce_ (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; default: utils::Error("unknown enum_op"); } } void Allgather(void *sendrecvbuf_, size_t total_size, size_t beginIndex, size_t size_node_slice, size_t size_prev_slice, int enum_dtype) { using namespace engine::mpi; size_t type_size = 0; switch (enum_dtype) { case kChar: type_size = sizeof(char); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kUChar: type_size = sizeof(unsigned char); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kInt: type_size = sizeof(int); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kUInt: type_size = sizeof(unsigned); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kLong: type_size = sizeof(int64_t); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kULong: type_size = sizeof(uint64_t); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kFloat: type_size = sizeof(float); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kDouble: type_size = sizeof(double); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; default: utils::Error("unknown data_type"); } } // wrapper for serialization struct ReadWrapper : public Serializable { std::string *p_str; explicit ReadWrapper(std::string *p_str) : p_str(p_str) {} virtual void Load(Stream *fi) { uint64_t sz; utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, "Read pickle string"); p_str->resize(sz); if (sz != 0) { utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0, "Read pickle string"); } } virtual void Save(Stream *fo) const { utils::Error("not implemented"); } }; struct WriteWrapper : public Serializable { const char *data; size_t length; explicit WriteWrapper(const char *data, size_t length) : data(data), length(length) { } virtual void Load(Stream *fi) { utils::Error("not implemented"); } virtual void Save(Stream *fo) const { uint64_t sz = static_cast(length); fo->Write(&sz, sizeof(sz)); fo->Write(data, length * sizeof(char)); } }; } // namespace c_api } // namespace rabit RABIT_DLL bool RabitInit(int argc, char *argv[]) { return rabit::Init(argc, argv); } RABIT_DLL bool RabitFinalize() { return rabit::Finalize(); } RABIT_DLL int RabitGetRingPrevRank() { return rabit::GetRingPrevRank(); } RABIT_DLL int RabitGetRank() { return rabit::GetRank(); } RABIT_DLL int RabitGetWorldSize() { return rabit::GetWorldSize(); } RABIT_DLL int RabitIsDistributed() { return rabit::IsDistributed(); } RABIT_DLL void RabitTrackerPrint(const char *msg) { std::string m(msg); rabit::TrackerPrint(m); } RABIT_DLL void RabitGetProcessorName(char *out_name, rbt_ulong *out_len, rbt_ulong max_len) { std::string s = rabit::GetProcessorName(); if (s.length() > max_len) { s.resize(max_len - 1); } strcpy(out_name, s.c_str()); // NOLINT(*) *out_len = static_cast(s.length()); } RABIT_DLL void RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root) { rabit::Broadcast(sendrecv_data, size, root); } RABIT_DLL void RabitAllgather(void *sendrecvbuf_, size_t total_size, size_t beginIndex, size_t size_node_slice, size_t size_prev_slice, int enum_dtype) { rabit::c_api::Allgather(sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice, static_cast(enum_dtype)); } RABIT_DLL void RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, int enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { rabit::c_api::Allreduce (sendrecvbuf, count, static_cast(enum_dtype), static_cast(enum_op), prepare_fun, prepare_arg); } RABIT_DLL int RabitLoadCheckPoint(char **out_global_model, rbt_ulong *out_global_len, char **out_local_model, rbt_ulong *out_local_len) { // NOTE: this function is not thread-safe using rabit::BeginPtr; using namespace rabit::c_api; // NOLINT(*) static std::string global_buffer; static std::string local_buffer; ReadWrapper sg(&global_buffer); ReadWrapper sl(&local_buffer); int version; if (out_local_model == NULL) { version = rabit::LoadCheckPoint(&sg, NULL); *out_global_model = BeginPtr(global_buffer); *out_global_len = static_cast(global_buffer.length()); } else { version = rabit::LoadCheckPoint(&sg, &sl); *out_global_model = BeginPtr(global_buffer); *out_global_len = static_cast(global_buffer.length()); *out_local_model = BeginPtr(local_buffer); *out_local_len = static_cast(local_buffer.length()); } return version; } RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len, const char *local_model, rbt_ulong local_len) { using namespace rabit::c_api; // NOLINT(*) WriteWrapper sg(global_model, global_len); WriteWrapper sl(local_model, local_len); if (local_model == NULL) { rabit::CheckPoint(&sg, NULL); } else { rabit::CheckPoint(&sg, &sl); } } RABIT_DLL int RabitVersionNumber() { return rabit::VersionNumber(); } RABIT_DLL int RabitLinkTag() { return 0; }