/*!
 *  Copyright (c) 2014 by Contributors
 * \file engine_mpi.cc
 * \brief This file gives an implementation of the engine interface using MPI.
 *   It allows a rabit program to run with MPI, but does not come with fault tolerance.
 *
 * \author Tianqi Chen
 */
#define NOMINMAX
#include <mpi.h>
#include <cstdio>
#include <string>
#include <cstring>
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"

namespace rabit {
namespace engine {
/*! \brief implementation of engine using MPI */
class MPIEngine : public IEngine {
 public:
  MPIEngine(void) {
    version_number = 0;
  }
  void Allgather(void *sendrecvbuf_,
                 size_t total_size,
                 size_t slice_begin,
                 size_t slice_end,
                 size_t size_prev_slice) override {
    utils::Error("MPIEngine:: Allgather is not supported");
  }
  void Allreduce(void *sendrecvbuf_,
                 size_t type_nbytes,
                 size_t count,
                 ReduceFunction reducer,
                 PreprocFunction prepare_fun,
                 void *prepare_arg) override {
    utils::Error("MPIEngine:: Allreduce is not supported, "
                 "use Allreduce_ instead");
  }
  int GetRingPrevRank(void) const override {
    utils::Error("MPIEngine:: GetRingPrevRank is not supported");
    return -1;
  }
  void Broadcast(void *sendrecvbuf_, size_t size, int root) override {
    MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
  }
  virtual void InitAfterException(void) {
    utils::Error("MPI is not fault tolerant");
  }
  virtual int LoadCheckPoint(Serializable *global_model,
                             Serializable *local_model = NULL) {
    return 0;
  }
  virtual void CheckPoint(const Serializable *global_model,
                          const Serializable *local_model = NULL) {
    version_number += 1;
  }
  virtual void LazyCheckPoint(const Serializable *global_model) {
    version_number += 1;
  }
  virtual int VersionNumber(void) const {
    return version_number;
  }
  /*! \brief get rank of current node */
  virtual int GetRank(void) const {
    return MPI::COMM_WORLD.Get_rank();
  }
  /*! \brief get total number of nodes */
  virtual int GetWorldSize(void) const {
    return MPI::COMM_WORLD.Get_size();
  }
  /*! \brief whether it is distributed */
  virtual bool IsDistributed(void) const {
    return true;
  }
  /*! \brief get the host name of current node */
  virtual std::string GetHost(void) const {
    int len;
    char name[MPI_MAX_PROCESSOR_NAME];
    MPI::Get_processor_name(name, len);
    name[len] = '\0';
    return std::string(name);
  }
  virtual void TrackerPrint(const std::string &msg) {
    // simply print information into the tracker
    if (GetRank() == 0) {
      utils::Printf("%s", msg.c_str());
    }
  }

 private:
  int version_number;
};

// singleton sync manager
MPIEngine manager;

/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
  try {
    MPI::Init(argc, argv);
    return true;
  } catch (const std::exception& e) {
    fprintf(stderr, "failed in MPI Init: %s\n", e.what());
    return false;
  }
}
/*! \brief finalize the synchronization module */
bool Finalize(void) {
  try {
    MPI::Finalize();
    return true;
  } catch (const std::exception& e) {
    fprintf(stderr, "failed in MPI shutdown: %s\n", e.what());
    return false;
  }
}
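// The two functions above mirror the lifecycle of a plain MPI program. As an
// illustrative sketch (not part of this translation unit), a rabit-on-MPI
// entry point would bracket its distributed work with them:
//
//   int main(int argc, char *argv[]) {
//     if (!rabit::engine::Init(argc, argv)) return 1;
//     // ... distributed work through the engine ...
//     rabit::engine::Finalize();
//     return 0;
//   }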
/*! \brief singleton method to get engine */
IEngine *GetEngine(void) {
  return &manager;
}
// transform enum to MPI data type
inline MPI::Datatype GetType(mpi::DataType dtype) {
  using namespace mpi;
  switch (dtype) {
    case kChar: return MPI::CHAR;
    case kUChar: return MPI::BYTE;
    case kInt: return MPI::INT;
    case kUInt: return MPI::UNSIGNED;
    case kLong: return MPI::LONG;
    case kULong: return MPI::UNSIGNED_LONG;
    case kFloat: return MPI::FLOAT;
    case kDouble: return MPI::DOUBLE;
    case kLongLong: return MPI::LONG_LONG;
    case kULongLong: return MPI::UNSIGNED_LONG_LONG;
  }
  utils::Error("unknown mpi::DataType");
  return MPI::CHAR;
}
// transform enum to MPI OP
inline MPI::Op GetOp(mpi::OpType otype) {
  using namespace mpi;
  switch (otype) {
    case kMax: return MPI::MAX;
    case kMin: return MPI::MIN;
    case kSum: return MPI::SUM;
    case kBitwiseOR: return MPI::BOR;
  }
  utils::Error("unknown mpi::OpType");
  return MPI::MAX;
}
// perform in-place allreduce on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
                size_t type_nbytes,
                size_t count,
                IEngine::ReduceFunction red,
                mpi::DataType dtype,
                mpi::OpType op,
                IEngine::PreprocFunction prepare_fun,
                void *prepare_arg) {
  // run the optional preprocessing hook before the reduction starts
  if (prepare_fun != NULL) prepare_fun(prepare_arg);
  // the custom reducer `red` is unused on this code path: the MPI backend
  // maps the operation enum onto MPI's built-in reduction ops instead
  MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count,
                            GetType(dtype), GetOp(op));
}
}  // namespace engine
}  // namespace rabit
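// Illustrative usage of Allreduce_ above (a sketch, not part of this file):
// summing a float buffer in place across all ranks. NULL is passed for the
// custom reducer, which this MPI backend ignores, and for the preprocessing
// hook and its argument.
//
//   float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
//   rabit::engine::Allreduce_(data, sizeof(float), 4,
//                             NULL,  // custom reducer: unused on the MPI path
//                             rabit::engine::mpi::kFloat,
//                             rabit::engine::mpi::kSum,
//                             NULL, NULL);  // no preprocessing hook
//   // after the call, every rank holds the element-wise global sums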