rabit run on MPI
This commit is contained in:
parent
2fab05c83e
commit
0a3300d773
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file engine_base.cc
|
* \file allreduce_base.cc
|
||||||
* \brief Basic implementation of AllReduce
|
* \brief Basic implementation of AllReduce
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
@ -8,7 +8,7 @@
|
|||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "./engine_base.h"
|
#include "./allreduce_base.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file engine_base.h
|
* \file allreduce_base.h
|
||||||
* \brief Basic implementation of AllReduce
|
* \brief Basic implementation of AllReduce
|
||||||
* using TCP non-block socket and tree-shape reduction.
|
* using TCP non-block socket and tree-shape reduction.
|
||||||
*
|
*
|
||||||
@ -8,8 +8,8 @@
|
|||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_ENGINE_BASE_H
|
#ifndef RABIT_ALLREDUCE_BASE_H
|
||||||
#define RABIT_ENGINE_BASE_H
|
#define RABIT_ALLREDUCE_BASE_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -275,4 +275,4 @@ class AllReduceBase : public IEngine {
|
|||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_ENGINE_BASE_H
|
#endif // RABIT_ALLREDUCE_BASE_H
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file engine_robust-inl.h
|
* \file allreduce_robust-inl.h
|
||||||
* \brief implementation of inline template function in AllReduceRobust
|
* \brief implementation of inline template function in AllReduceRobust
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file engine_robust.cc
|
* \file allreduce_robust.cc
|
||||||
* \brief Robust implementation of AllReduce
|
* \brief Robust implementation of AllReduce
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
@ -11,7 +11,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
#include "./engine_robust.h"
|
#include "./allreduce_robust.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file engine_robust.h
|
* \file allreduce_robust.h
|
||||||
* \brief Robust implementation of AllReduce
|
* \brief Robust implementation of AllReduce
|
||||||
* using TCP non-block socket and tree-shape reduction.
|
* using TCP non-block socket and tree-shape reduction.
|
||||||
*
|
*
|
||||||
@ -7,11 +7,11 @@
|
|||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_ENGINE_ROBUST_H
|
#ifndef RABIT_ALLREDUCE_ROBUST_H
|
||||||
#define RABIT_ENGINE_ROBUST_H
|
#define RABIT_ALLREDUCE_ROBUST_H
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#include "./engine_base.h"
|
#include "./allreduce_base.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -70,7 +70,7 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
* this function is only used for test purpose
|
* this function is only used for test purpose
|
||||||
*/
|
*/
|
||||||
virtual void InitAfterException(void) {
|
virtual void InitAfterException(void) {
|
||||||
this->CheckAndRecover(kGetExcept);
|
//this->CheckAndRecover(kGetExcept);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -371,6 +371,6 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
// implementation of inline template function
|
// implementation of inline template function
|
||||||
#include "./engine_robust-inl.h"
|
#include "./allreduce_robust-inl.h"
|
||||||
|
|
||||||
#endif // RABIT_ENGINE_ROBUST_H
|
#endif // RABIT_ALLREDUCE_ROBUST_H
|
||||||
@ -10,8 +10,8 @@
|
|||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
|
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#include "./engine_base.h"
|
#include "./allreduce_base.h"
|
||||||
#include "./engine_robust.h"
|
#include "./allreduce_robust.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -37,5 +37,14 @@ void Finalize(void) {
|
|||||||
IEngine *GetEngine(void) {
|
IEngine *GetEngine(void) {
|
||||||
return &manager;
|
return &manager;
|
||||||
}
|
}
|
||||||
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
|
void AllReduce_(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
IEngine::ReduceFunction red,
|
||||||
|
mpi::DataType dtype,
|
||||||
|
mpi::OpType op) {
|
||||||
|
GetEngine()->AllReduce(sendrecvbuf, type_nbytes, count, red);
|
||||||
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
31
src/engine.h
31
src/engine.h
@ -105,6 +105,37 @@ void Finalize(void);
|
|||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void);
|
IEngine *GetEngine(void);
|
||||||
|
|
||||||
|
/*! \brief namespace that contains stuff needed to be compatible with MPI */
|
||||||
|
namespace mpi {
|
||||||
|
/*!\brief enum of all operators */
|
||||||
|
enum OpType {
|
||||||
|
kMax, kMin, kSum, kBitwiseOR
|
||||||
|
};
|
||||||
|
/*!\brief enum of supported data types */
|
||||||
|
enum DataType {
|
||||||
|
kInt,
|
||||||
|
kUInt,
|
||||||
|
kDouble,
|
||||||
|
kFloat
|
||||||
|
};
|
||||||
|
} // namespace mpi
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* this is an internal function used by rabit to be able to compile with MPI
|
||||||
|
* do not use this function directly
|
||||||
|
* \param sendrecvbuf buffer for both sending and receiving data
|
||||||
|
* \param type_nbytes the size in bytes of a single element of the type
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param red reduce function
|
||||||
|
* \param dtype the data type
|
||||||
|
* \param op the reduce operator type
|
||||||
|
*/
|
||||||
|
void AllReduce_(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
IEngine::ReduceFunction red,
|
||||||
|
mpi::DataType dtype,
|
||||||
|
mpi::OpType op);
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_ENGINE_H
|
#endif // RABIT_ENGINE_H
|
||||||
|
|||||||
115
src/engine_mpi.cc
Normal file
115
src/engine_mpi.cc
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
/*!
|
||||||
|
* \file engine_mpi.cc
|
||||||
|
* \brief this file gives an implementation of engine interface using MPI,
|
||||||
|
* this allows rabit programs to run with MPI, but does not come with fault tolerance
|
||||||
|
*
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
|
#define NOMINMAX
|
||||||
|
#include "./engine.h"
|
||||||
|
#include "./utils.h"
|
||||||
|
#include <mpi.h>
|
||||||
|
|
||||||
|
namespace rabit {
|
||||||
|
namespace engine {
|
||||||
|
/*! \brief implementation of engine using MPI */
|
||||||
|
class MPIEngine : public IEngine {
|
||||||
|
public:
|
||||||
|
MPIEngine(void) {
|
||||||
|
version_number = 0;
|
||||||
|
}
|
||||||
|
virtual void AllReduce(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer) {
|
||||||
|
utils::Error("MPIEngine:: AllReduce is not supported, use AllReduce_ instead");
|
||||||
|
}
|
||||||
|
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||||
|
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
|
||||||
|
}
|
||||||
|
virtual void InitAfterException(void) {
|
||||||
|
utils::Error("MPI is not fault tolerant");
|
||||||
|
}
|
||||||
|
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
virtual void CheckPoint(const utils::ISerializable &model) {
|
||||||
|
version_number += 1;
|
||||||
|
}
|
||||||
|
virtual int VersionNumber(void) const {
|
||||||
|
return version_number;
|
||||||
|
}
|
||||||
|
/*! \brief get rank of current node */
|
||||||
|
virtual int GetRank(void) const {
|
||||||
|
return MPI::COMM_WORLD.Get_rank();
|
||||||
|
}
|
||||||
|
/*! \brief get total number of processes */
|
||||||
|
virtual int GetWorldSize(void) const {
|
||||||
|
return MPI::COMM_WORLD.Get_size();
|
||||||
|
}
|
||||||
|
/*! \brief get the host name of current node */
|
||||||
|
virtual std::string GetHost(void) const {
|
||||||
|
int len;
|
||||||
|
char name[MPI_MAX_PROCESSOR_NAME];
|
||||||
|
MPI::Get_processor_name(name, len);
|
||||||
|
name[len] = '\0';
|
||||||
|
return std::string(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int version_number;
|
||||||
|
};
|
||||||
|
|
||||||
|
// singleton sync manager
|
||||||
|
MPIEngine manager;
|
||||||
|
|
||||||
|
/*! \brief initialize the synchronization module */
|
||||||
|
void Init(int argc, char *argv[]) {
|
||||||
|
MPI::Init(argc, argv);
|
||||||
|
}
|
||||||
|
/*! \brief finalize the synchronization module */
|
||||||
|
void Finalize(void) {
|
||||||
|
MPI::Finalize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief singleton method to get engine */
|
||||||
|
IEngine *GetEngine(void) {
|
||||||
|
return &manager;
|
||||||
|
}
|
||||||
|
// transform enum to MPI data type
|
||||||
|
inline MPI::Datatype GetType(mpi::DataType dtype) {
|
||||||
|
using namespace mpi;
|
||||||
|
switch(dtype) {
|
||||||
|
case kInt: return MPI::INT;
|
||||||
|
case kUInt: return MPI::UNSIGNED;
|
||||||
|
case kFloat: return MPI::FLOAT;
|
||||||
|
case kDouble: return MPI::DOUBLE;
|
||||||
|
}
|
||||||
|
utils::Error("unknown mpi::DataType");
|
||||||
|
return MPI::CHAR;
|
||||||
|
}
|
||||||
|
// transform enum to MPI OP
|
||||||
|
inline MPI::Op GetOp(mpi::OpType otype) {
|
||||||
|
using namespace mpi;
|
||||||
|
switch(otype) {
|
||||||
|
case kMax: return MPI::MAX;
|
||||||
|
case kMin: return MPI::MIN;
|
||||||
|
case kSum: return MPI::SUM;
|
||||||
|
case kBitwiseOR: return MPI::BOR;
|
||||||
|
}
|
||||||
|
utils::Error("unknown mpi::OpType");
|
||||||
|
return MPI::MAX;
|
||||||
|
}
|
||||||
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
|
void AllReduce_(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
IEngine::ReduceFunction red,
|
||||||
|
mpi::DataType dtype,
|
||||||
|
mpi::OpType op) {
|
||||||
|
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
|
||||||
|
}
|
||||||
|
} // namespace engine
|
||||||
|
} // namespace rabit
|
||||||
123
src/rabit-inl.h
Normal file
123
src/rabit-inl.h
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
/*!
|
||||||
|
* \file rabit-inl.h
|
||||||
|
* \brief implementation of inline template function for rabit interface
|
||||||
|
*
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#ifndef RABIT_RABIT_INL_H
|
||||||
|
#define RABIT_RABIT_INL_H
|
||||||
|
|
||||||
|
namespace rabit {
|
||||||
|
namespace engine {
|
||||||
|
namespace mpi {
|
||||||
|
// template function to translate type to enum indicator
|
||||||
|
template<typename DType>
|
||||||
|
inline DataType GetType(void);
|
||||||
|
template<>
|
||||||
|
inline DataType GetType<int>(void) {
|
||||||
|
return kInt;
|
||||||
|
}
|
||||||
|
template<>
|
||||||
|
inline DataType GetType<unsigned>(void) {
|
||||||
|
return kUInt;
|
||||||
|
}
|
||||||
|
template<>
|
||||||
|
inline DataType GetType<float>(void) {
|
||||||
|
return kFloat;
|
||||||
|
}
|
||||||
|
template<>
|
||||||
|
inline DataType GetType<double>(void) {
|
||||||
|
return kDouble;
|
||||||
|
}
|
||||||
|
} // namespace mpi
|
||||||
|
} // namespace engine
|
||||||
|
|
||||||
|
namespace op {
|
||||||
|
struct Max {
|
||||||
|
const static engine::mpi::OpType kType = engine::mpi::kMax;
|
||||||
|
template<typename DType>
|
||||||
|
inline static void Reduce(DType &dst, const DType &src) {
|
||||||
|
if (dst < src) dst = src;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct Min {
|
||||||
|
const static engine::mpi::OpType kType = engine::mpi::kMin;
|
||||||
|
template<typename DType>
|
||||||
|
inline static void Reduce(DType &dst, const DType &src) {
|
||||||
|
if (dst > src) dst = src;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct Sum {
|
||||||
|
const static engine::mpi::OpType kType = engine::mpi::kSum;
|
||||||
|
template<typename DType>
|
||||||
|
inline static void Reduce(DType &dst, const DType &src) {
|
||||||
|
dst += src;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct BitOR {
|
||||||
|
const static engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
||||||
|
template<typename DType>
|
||||||
|
inline static void Reduce(DType &dst, const DType &src) {
|
||||||
|
dst |= src;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template<typename OP, typename DType>
|
||||||
|
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||||
|
const DType *src = (const DType*)src_;
|
||||||
|
DType *dst = (DType*)dst_;
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
OP::Reduce(dst[i], src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace op
|
||||||
|
|
||||||
|
// initialize the rabit engine
|
||||||
|
inline void Init(int argc, char *argv[]) {
|
||||||
|
engine::Init(argc, argv);
|
||||||
|
}
|
||||||
|
// finalize the rabit engine
|
||||||
|
inline void Finalize(void) {
|
||||||
|
engine::Finalize();
|
||||||
|
}
|
||||||
|
// get the rank of current process
|
||||||
|
inline int GetRank(void) {
|
||||||
|
return engine::GetEngine()->GetRank();
|
||||||
|
}
|
||||||
|
// get the size of the world
|
||||||
|
inline int GetWorldSize(void) {
|
||||||
|
return engine::GetEngine()->GetWorldSize();
|
||||||
|
}
|
||||||
|
// get the name of current processor
|
||||||
|
inline std::string GetProcessorName(void) {
|
||||||
|
return engine::GetEngine()->GetHost();
|
||||||
|
}
|
||||||
|
// broadcast an std::string to all others from root
|
||||||
|
inline void Bcast(std::string *sendrecv_data, int root) {
|
||||||
|
engine::IEngine *e = engine::GetEngine();
|
||||||
|
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
||||||
|
e->Broadcast(&len, sizeof(len), root);
|
||||||
|
sendrecv_data->resize(len);
|
||||||
|
if (len != 0) {
|
||||||
|
e->Broadcast(&(*sendrecv_data)[0], len, root);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// perform inplace AllReduce
|
||||||
|
template<typename OP, typename DType>
|
||||||
|
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
||||||
|
engine::AllReduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
||||||
|
engine::mpi::GetType<DType>(), OP::kType);
|
||||||
|
}
|
||||||
|
// load latest check point
|
||||||
|
inline int LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
|
return engine::GetEngine()->LoadCheckPoint(p_model);
|
||||||
|
}
|
||||||
|
// checkpoint the model, meaning we finished a stage of execution
|
||||||
|
inline void CheckPoint(const utils::ISerializable &model) {
|
||||||
|
engine::GetEngine()->CheckPoint(model);
|
||||||
|
}
|
||||||
|
// return the version number of currently stored model
|
||||||
|
inline int VersionNumber(void) {
|
||||||
|
return engine::GetEngine()->VersionNumber();
|
||||||
|
}
|
||||||
|
} // namespace rabit
|
||||||
|
#endif
|
||||||
96
src/rabit.h
96
src/rabit.h
@ -2,8 +2,9 @@
|
|||||||
#define RABIT_RABIT_H
|
#define RABIT_RABIT_H
|
||||||
/*!
|
/*!
|
||||||
* \file rabit.h
|
* \file rabit.h
|
||||||
* \brief This file defines a template wrapper of engine to give more flexible
|
* \brief This file defines unified AllReduce/Broadcast interface of rabit
|
||||||
* AllReduce operations
|
* The actual implementation is redirected to rabit engine
|
||||||
|
* Code that only uses this header can also be compiled with MPI AllReduce (with no fault recovery),
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
@ -13,53 +14,32 @@
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
/*! \brief namespace of operator */
|
/*! \brief namespace of operator */
|
||||||
namespace op {
|
namespace op {
|
||||||
struct Max {
|
/*! \brief maximum value */
|
||||||
template<typename DType>
|
struct Max;
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
/*! \brief minimum value */
|
||||||
if (dst < src) dst = src;
|
struct Min;
|
||||||
}
|
/*! \brief perform sum */
|
||||||
};
|
struct Sum;
|
||||||
struct Sum {
|
/*! \brief perform bitwise OR */
|
||||||
template<typename DType>
|
struct BitOR;
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
|
||||||
dst += src;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
struct BitOR {
|
|
||||||
template<typename DType>
|
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
|
||||||
dst |= src;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
template<typename OP, typename DType>
|
|
||||||
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
|
||||||
const DType *src = (const DType*)src_;
|
|
||||||
DType *dst = (DType*)dst_;
|
|
||||||
for (int i = 0; i < len; ++i) {
|
|
||||||
OP::Reduce(dst[i], src[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
void Init(int argc, char *argv[]) {
|
/*!
|
||||||
engine::Init(argc, argv);
|
* \brief initialize the rabit module, call this function once before using anything
|
||||||
}
|
* \param argc number of arguments in argv
|
||||||
void Finalize(void) {
|
* \param argv the array of input arguments
|
||||||
engine::Finalize();
|
*/
|
||||||
}
|
inline void Init(int argc, char *argv[]);
|
||||||
|
/*!
|
||||||
|
* \brief finalize the rabit engine, call this function after you finished all jobs
|
||||||
|
*/
|
||||||
|
inline void Finalize(void);
|
||||||
/*! \brief get rank of current process */
|
/*! \brief get rank of current process */
|
||||||
inline int GetRank(void) {
|
inline int GetRank(void);
|
||||||
return engine::GetEngine()->GetRank();
|
|
||||||
}
|
|
||||||
/*! \brief get total number of process */
|
/*! \brief get total number of process */
|
||||||
int GetWorldSize(void) {
|
inline int GetWorldSize(void);
|
||||||
return engine::GetEngine()->GetWorldSize();
|
|
||||||
}
|
|
||||||
/*! \brief get name of processor */
|
/*! \brief get name of processor */
|
||||||
std::string GetProcessorName(void) {
|
inline std::string GetProcessorName(void);
|
||||||
return engine::GetEngine()->GetHost();
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \brief broadcast an std::string to all others from root
|
* \brief broadcast an std::string to all others from root
|
||||||
* \param sendrecv_data the pointer to send or recive buffer,
|
* \param sendrecv_data the pointer to send or recive buffer,
|
||||||
@ -67,15 +47,7 @@ std::string GetProcessorName(void) {
|
|||||||
* and string will be resized to correct length
|
* and string will be resized to correct length
|
||||||
* \param root the root of process
|
* \param root the root of process
|
||||||
*/
|
*/
|
||||||
inline void Bcast(std::string *sendrecv_data, int root) {
|
inline void Bcast(std::string *sendrecv_data, int root);
|
||||||
engine::IEngine *e = engine::GetEngine();
|
|
||||||
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
|
||||||
e->Broadcast(&len, sizeof(len), root);
|
|
||||||
sendrecv_data->resize(len);
|
|
||||||
if (len != 0) {
|
|
||||||
e->Broadcast(&(*sendrecv_data)[0], len, root);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
@ -90,9 +62,7 @@ inline void Bcast(std::string *sendrecv_data, int root) {
|
|||||||
* \tparam DType type of data
|
* \tparam DType type of data
|
||||||
*/
|
*/
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
inline void AllReduce(DType *sendrecvbuf, size_t count);
|
||||||
engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>);
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param p_model pointer to the model
|
* \param p_model pointer to the model
|
||||||
@ -110,9 +80,7 @@ inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
inline int LoadCheckPoint(utils::ISerializable *p_model) {
|
inline int LoadCheckPoint(utils::ISerializable *p_model);
|
||||||
return engine::GetEngine()->LoadCheckPoint(p_model);
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
@ -120,16 +88,14 @@ inline int LoadCheckPoint(utils::ISerializable *p_model) {
|
|||||||
* \param p_model pointer to the model
|
* \param p_model pointer to the model
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
inline void CheckPoint(const utils::ISerializable &model) {
|
inline void CheckPoint(const utils::ISerializable &model);
|
||||||
engine::GetEngine()->CheckPoint(model);
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
* \sa LoadCheckPoint, CheckPoint
|
* \sa LoadCheckPoint, CheckPoint
|
||||||
*/
|
*/
|
||||||
inline int VersionNumber(void) {
|
inline int VersionNumber(void);
|
||||||
return engine::GetEngine()->VersionNumber();
|
|
||||||
}
|
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
// implementation of template functions
|
||||||
|
#include "./rabit-inl.h"
|
||||||
#endif // RABIT_ALLREDUCE_H
|
#endif // RABIT_ALLREDUCE_H
|
||||||
|
|||||||
2
test/.gitignore
vendored
Normal file
2
test/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
*.mpi
|
||||||
|
test_*
|
||||||
@ -4,26 +4,31 @@ export MPICXX = mpicxx
|
|||||||
export LDFLAGS= -pthread -lm
|
export LDFLAGS= -pthread -lm
|
||||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
|
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
|
||||||
|
|
||||||
ifeq ($(no_omp),1)
|
|
||||||
CFLAGS += -DDISABLE_OPENMP
|
|
||||||
else
|
|
||||||
CFLAGS += -fopenmp
|
|
||||||
endif
|
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = test_allreduce test_recover test_model_recover
|
BIN = test_allreduce test_recover test_model_recover
|
||||||
OBJ = engine_base.o engine_robust.o engine.o
|
# object files that make up the rabit library
|
||||||
|
RABIT_OBJ = allreduce_base.o allreduce_robust.o engine.o
|
||||||
|
MPIOBJ = engine_mpi.o
|
||||||
|
|
||||||
|
OBJ = $(RABIT_OBJ) test_allreduce.o test_recover.o test_model_recover.o
|
||||||
|
MPIBIN = test_allreduce.mpi
|
||||||
.PHONY: clean all
|
.PHONY: clean all
|
||||||
|
|
||||||
all: $(BIN) $(MPIBIN)
|
all: $(BIN) $(MPIBIN)
|
||||||
|
|
||||||
engine_tcp.o: ../src/engine_tcp.cpp ../src/*.h
|
allreduce_base.o: ../src/allreduce_base.cc ../src/*.h
|
||||||
engine_base.o: ../src/engine_base.cc ../src/*.h
|
|
||||||
engine.o: ../src/engine.cc ../src/*.h
|
engine.o: ../src/engine.cc ../src/*.h
|
||||||
engine_robust.o: ../src/engine_robust.cc ../src/*.h
|
allreduce_robust.o: ../src/allreduce_robust.cc ../src/*.h
|
||||||
test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ)
|
engine_mpi.o: ../src/engine_mpi.cc
|
||||||
test_recover: test_recover.cpp ../src/*.h $(OBJ)
|
test_allreduce.o: test_allreduce.cpp ../src/*.h
|
||||||
test_model_recover: test_model_recover.cpp ../src/*.h $(OBJ)
|
test_recover.o: test_recover.cpp ../src/*.h
|
||||||
|
test_model_recover.o: test_model_recover.cpp ../src/*.h
|
||||||
|
|
||||||
|
# we can link against the MPI version of the engine to use MPI
|
||||||
|
test_allreduce: test_allreduce.o $(RABIT_OBJ)
|
||||||
|
test_allreduce.mpi: test_allreduce.o $(MPIOBJ)
|
||||||
|
test_recover: test_recover.o $(RABIT_OBJ)
|
||||||
|
test_model_recover: test_model_recover.o $(RABIT_OBJ)
|
||||||
|
|
||||||
$(BIN) :
|
$(BIN) :
|
||||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)
|
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)
|
||||||
@ -32,7 +37,10 @@ $(OBJ) :
|
|||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|
||||||
$(MPIBIN) :
|
$(MPIBIN) :
|
||||||
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)
|
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)
|
||||||
|
|
||||||
|
$(MPIOBJ) :
|
||||||
|
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
$(RM) $(OBJ) $(BIN) $(MPIBIN) *~ ../src/*~
|
$(RM) $(OBJ) $(BIN) $(MPIBIN) *~ ../src/*~
|
||||||
|
|||||||
@ -79,7 +79,8 @@ int main(int argc, char *argv[]) {
|
|||||||
utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
|
utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
|
||||||
TestSum(mock, n);
|
TestSum(mock, n);
|
||||||
utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
|
utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
|
||||||
for (int i = 0; i < nproc; i += nproc / 3) {
|
int step = std::max(nproc / 3, 1);
|
||||||
|
for (int i = 0; i < nproc; i += step) {
|
||||||
TestBcast(mock, n, i);
|
TestBcast(mock, n, i);
|
||||||
}
|
}
|
||||||
utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
|
utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
|
||||||
|
|||||||
@ -132,7 +132,7 @@ int main(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
} catch (MockException &e) {
|
} catch (MockException &e) {
|
||||||
//rabit::engine::GetEngine()->InitAfterException();
|
rabit::engine::GetEngine()->InitAfterException();
|
||||||
++ntrial;
|
++ntrial;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -115,7 +115,7 @@ int main(int argc, char *argv[]) {
|
|||||||
// reach here
|
// reach here
|
||||||
break;
|
break;
|
||||||
} catch (MockException &e) {
|
} catch (MockException &e) {
|
||||||
//rabit::engine::GetEngine()->InitAfterException();
|
rabit::engine::GetEngine()->InitAfterException();
|
||||||
++ntrial;
|
++ntrial;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user