// Copyright by Contributors // implementations in ctypes #include #include #include #include "rabit/rabit.h" #include "rabit/c_api.h" #include "../../src/c_api/c_api_error.h" namespace rabit { namespace c_api { // helper use to avoid BitOR operator template struct FHelper { static void Allreduce(DType *senrecvbuf_, size_t count, void (*prepare_fun)(void *arg), void *prepare_arg) { rabit::Allreduce(senrecvbuf_, count, prepare_fun, prepare_arg); } }; template struct FHelper { static void Allreduce(DType *, size_t , void (*)(void *arg), void *) { utils::Error("DataType does not support bitwise AND operation"); } }; template struct FHelper { static void Allreduce(DType *, size_t , void (*)(void *arg), void *) { utils::Error("DataType does not support bitwise OR operation"); } }; template struct FHelper { static void Allreduce(DType *, size_t , void (*)(void *arg), void *) { utils::Error("DataType does not support bitwise XOR operation"); } }; template void Allreduce(void *sendrecvbuf_, size_t count, engine::mpi::DataType enum_dtype, void (*prepare_fun)(void *arg), void *prepare_arg) { using namespace engine::mpi; // NOLINT switch (enum_dtype) { case kChar: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kUChar: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kInt: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kUInt: rabit::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kLong: rabit::Allreduce (static_cast(sendrecvbuf_), // NOLINT(*) count, prepare_fun, prepare_arg); return; case kULong: rabit::Allreduce (static_cast(sendrecvbuf_), // NOLINT(*) count, prepare_fun, prepare_arg); return; case kFloat: FHelper::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; case kDouble: FHelper::Allreduce (static_cast(sendrecvbuf_), count, prepare_fun, prepare_arg); return; default: utils::Error("unknown data_type"); } } void Allreduce(void *sendrecvbuf, size_t count, engine::mpi::DataType enum_dtype, engine::mpi::OpType enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { using namespace engine::mpi; // NOLINT switch (enum_op) { case kMax: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kMin: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kSum: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kBitwiseAND: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kBitwiseOR: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kBitwiseXOR: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; default: utils::Error("unknown enum_op"); } } void Allgather(void *sendrecvbuf_, size_t total_size, size_t beginIndex, size_t size_node_slice, size_t size_prev_slice, int enum_dtype) { using namespace engine::mpi; // NOLINT size_t type_size = 0; switch (enum_dtype) { case kChar: type_size = sizeof(char); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kUChar: type_size = sizeof(unsigned char); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kInt: type_size = sizeof(int); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kUInt: type_size = sizeof(unsigned); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kLong: type_size = sizeof(int64_t); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kULong: type_size = sizeof(uint64_t); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kFloat: type_size = sizeof(float); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; case kDouble: type_size = sizeof(double); rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, beginIndex * type_size, (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); break; default: utils::Error("unknown data_type"); } } // wrapper for serialization struct ReadWrapper : public Serializable { std::string *p_str; explicit ReadWrapper(std::string *p_str) : p_str(p_str) {} void Load(Stream *fi) override { uint64_t sz; utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, "Read pickle string"); p_str->resize(sz); if (sz != 0) { utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0, "Read pickle string"); } } void Save(Stream *) const override { utils::Error("not implemented"); } }; struct WriteWrapper : public Serializable { const char *data; size_t length; explicit WriteWrapper(const char *data, size_t length) : data(data), length(length) { } void Load(Stream *) override { utils::Error("not implemented"); } void Save(Stream *fo) const override { uint64_t sz = static_cast(length); fo->Write(&sz, sizeof(sz)); fo->Write(data, length * sizeof(char)); } }; } // namespace c_api } // namespace rabit RABIT_DLL bool RabitInit(int argc, char *argv[]) { auto ret = rabit::Init(argc, argv); if (!ret) { XGBAPISetLastError("Failed to initialize RABIT."); } return ret; } RABIT_DLL int RabitFinalize() { auto ret = rabit::Finalize(); if (!ret) { XGBAPISetLastError("Failed to shutdown RABIT worker."); } return static_cast(ret); } RABIT_DLL int RabitGetRingPrevRank() { return rabit::GetRingPrevRank(); } RABIT_DLL int RabitGetRank() { return rabit::GetRank(); } RABIT_DLL int RabitGetWorldSize() { return rabit::GetWorldSize(); } RABIT_DLL int RabitIsDistributed() { return rabit::IsDistributed(); } RABIT_DLL int RabitTrackerPrint(const char *msg) { API_BEGIN() std::string m(msg); rabit::TrackerPrint(m); API_END() } RABIT_DLL void RabitGetProcessorName(char *out_name, rbt_ulong *out_len, rbt_ulong max_len) { std::string s = rabit::GetProcessorName(); if (s.length() > max_len) { s.resize(max_len - 1); } strcpy(out_name, s.c_str()); // NOLINT(*) *out_len = static_cast(s.length()); } RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root) { API_BEGIN() rabit::Broadcast(sendrecv_data, size, root); API_END() } RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size, size_t beginIndex, size_t size_node_slice, size_t size_prev_slice, int enum_dtype) { API_BEGIN() rabit::c_api::Allgather( sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice, static_cast(enum_dtype)); API_END() } RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, int enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { API_BEGIN() rabit::c_api::Allreduce(sendrecvbuf, count, static_cast(enum_dtype), static_cast(enum_op), prepare_fun, prepare_arg); API_END() } RABIT_DLL int RabitVersionNumber() { return rabit::VersionNumber(); } RABIT_DLL int RabitLinkTag() { return 0; }