xgboost/rabit/src/engine_mpi.cc

163 lines
4.7 KiB
C++

/*!
* Copyright (c) 2014 by Contributors
* \file engine_mpi.cc
* \brief this file gives an implementation of the engine interface using MPI;
* this allows a rabit program to run with MPI, but it does not come with fault tolerance
*
* \author Tianqi Chen
*/
#define NOMINMAX
#include <mpi.h>
#include <rabit/base.h>
#include <cstdio>
#include <string>
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"
namespace rabit {
namespace engine {
/*! \brief implementation of engine using MPI */
class MPIEngine : public IEngine {
public:
MPIEngine(void) {
version_number = 0;
}
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice) override {
utils::Error("MPIEngine:: Allgather is not supported");
}
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg) override {
utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
int GetRingPrevRank(void) const override {
utils::Error("MPIEngine:: GetRingPrevRank is not supported");
return -1;
}
void Broadcast(void *sendrecvbuf_, size_t size, int root) override {
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
}
virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant");
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
return MPI::COMM_WORLD.Get_rank();
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
return MPI::COMM_WORLD.Get_size();
}
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
return true;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
int len;
char name[MPI_MAX_PROCESSOR_NAME];
MPI::Get_processor_name(name, len);
name[len] = '\0';
return std::string(name);
}
virtual void TrackerPrint(const std::string &msg) {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}
private:
int version_number;
};
// Singleton engine instance shared by the free functions of this file;
// GetEngine() hands out a pointer to this object.
MPIEngine manager;
/*!
 * \brief initialize the synchronization module by starting MPI
 * \param argc number of command line arguments, forwarded to MPI::Init
 * \param argv command line arguments, forwarded to MPI::Init
 * \return true when MPI::Init completed, false when it threw a std::exception
 */
bool Init(int argc, char *argv[]) {
  bool started = true;
  try {
    MPI::Init(argc, argv);
  } catch (const std::exception& err) {
    // report the failure on stderr and signal it through the return value
    std::fprintf(stderr, " failed in MPI Init %s\n", err.what());
    started = false;
  }
  return started;
}
/*!
 * \brief finalize the synchronization module by shutting MPI down
 * \return true when MPI::Finalize completed, false when it threw a std::exception
 */
bool Finalize(void) {
  bool stopped = true;
  try {
    MPI::Finalize();
  } catch (const std::exception& err) {
    // report the failure on stderr and signal it through the return value
    std::fprintf(stderr, "failed in MPI shutdown %s\n", err.what());
    stopped = false;
  }
  return stopped;
}
/*!
 * \brief singleton accessor for the engine
 * \return pointer to the file-level MPIEngine instance, viewed as IEngine
 */
IEngine *GetEngine(void) {
  IEngine *engine = &manager;
  return engine;
}
/*!
 * \brief translate a rabit data type enum into the matching MPI datatype
 * \param dtype the rabit-side element type
 * \return the corresponding MPI datatype; for an unrecognized value an error
 *         is reported through utils::Error and MPI::CHAR is the fallback
 */
inline MPI::Datatype GetType(mpi::DataType dtype) {
  using namespace mpi;
  switch (dtype) {
    // NOTE(review): the MPI standard lists MPI_CHAR as a printable-character
    // type that is excluded from predefined reductions; left unchanged here
    // because the signedness of char is platform dependent -- confirm.
    case kChar: return MPI::CHAR;
    // MPI::BYTE is only accepted by the bitwise reduction operators, so
    // Max/Min/Sum on unsigned char would be rejected. MPI::UNSIGNED_CHAR is
    // an integer type and works with the arithmetic and bitwise operators.
    case kUChar: return MPI::UNSIGNED_CHAR;
    case kInt: return MPI::INT;
    case kUInt: return MPI::UNSIGNED;
    case kLong: return MPI::LONG;
    case kULong: return MPI::UNSIGNED_LONG;
    case kFloat: return MPI::FLOAT;
    case kDouble: return MPI::DOUBLE;
    case kLongLong: return MPI::LONG_LONG;
    case kULongLong: return MPI::UNSIGNED_LONG_LONG;
  }
  utils::Error("unknown mpi::DataType");
  return MPI::CHAR;
}
/*!
 * \brief translate a rabit reduction operator enum into the matching MPI operator
 * \param otype the rabit-side reduction operator
 * \return the corresponding MPI operator; for an unrecognized value an error
 *         is reported through utils::Error and MPI::MAX is the fallback
 */
inline MPI::Op GetOp(mpi::OpType otype) {
  if (otype == mpi::kMax) return MPI::MAX;
  if (otype == mpi::kMin) return MPI::MIN;
  if (otype == mpi::kSum) return MPI::SUM;
  if (otype == mpi::kBitwiseOR) return MPI::BOR;
  utils::Error("unknown mpi::OpType");
  return MPI::MAX;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
count, GetType(dtype), GetOp(op));
}
} // namespace engine
} // namespace rabit