change file structure

This commit is contained in:
tqchen
2014-12-20 16:19:54 -08:00
parent 77d74f6c0d
commit 925d014271
24 changed files with 84 additions and 76 deletions

6
src/README.md Normal file
View File

@@ -0,0 +1,6 @@
Source Files of Rabit
====
* This folder contains the source files of rabit library
* The library headers are in folder [include](../include)
* The .h files in this folder are internal header files that are only used by rabit and will not be seen by users

View File

@@ -13,9 +13,9 @@
#include <vector>
#include <string>
#include "./utils.h"
#include <rabit/utils.h>
#include <rabit/engine.h>
#include "./socket.h"
#include "./engine.h"
namespace MPI {
// MPI data type to be compatible with existing MPI interface

View File

@@ -9,10 +9,11 @@
#define NOMINMAX
#include <limits>
#include <utility>
#include "./io.h"
#include "./utils.h"
#include <rabit/io.h>
#include <rabit/utils.h>
#include <rabit/engine.h>
#include <rabit/rabit-inl.h>
#include "./allreduce_robust.h"
#include "./rabit.h"
namespace rabit {
namespace engine {

View File

@@ -10,7 +10,7 @@
#ifndef RABIT_ALLREDUCE_ROBUST_H
#define RABIT_ALLREDUCE_ROBUST_H
#include <vector>
#include "./engine.h"
#include <rabit/engine.h>
#include "./allreduce_base.h"
namespace rabit {

View File

@@ -9,7 +9,7 @@
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include "./engine.h"
#include <rabit/engine.h>
#include "./allreduce_base.h"
#include "./allreduce_robust.h"

View File

@@ -1,224 +0,0 @@
/*!
* \file engine.h
* \brief This file defines the core interface of allreduce library
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_ENGINE_H
#define RABIT_ENGINE_H
#include "./serializable.h"
namespace MPI {
/*! \brief MPI data type just to be compatible with MPI reduce function*/
class Datatype;
}
/*! \brief namespace of rabit */
namespace rabit {
/*! \brief core interface of engine */
namespace engine {
/*! \brief interface of core Allreduce engine */
class IEngine {
public:
/*!
* \brief Preprocessing function, that is called before AllReduce,
* used to prepare the data used by AllReduce
* \param arg additional possible argument used to invoke the preprocessor
*/
typedef void (PreprocFunction) (void *arg);
/*!
* \brief reduce function, the same form of MPI reduce function is used,
* to be compatible with MPI interface
* In all the functions, the memory is ensured to aligned to 64-bit
* which means it is OK to cast src,dst to double* int* etc
* \param src pointer to source space
* \param dst pointer to destination reduction
* \param count total number of elements to be reduced(note this is total number of elements instead of bytes)
* the definition of reduce function should be type aware
* \param dtype the data type object, to be compatible with MPI reduce
*/
typedef void (ReduceFunction) (const void *src,
void *dst, int count,
const MPI::Datatype &dtype);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) = 0;
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
*/
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) = 0;
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) = 0;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
/*! \brief get rank of current node */
virtual int GetRank(void) const = 0;
/*! \brief get total number of */
virtual int GetWorldSize(void) const = 0;
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const = 0;
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg) = 0;
};
/*! \brief intiialize the engine module */
void Init(int argc, char *argv[]);
/*! \brief finalize engine module */
void Finalize(void);
/*! \brief singleton method to get engine */
IEngine *GetEngine(void);
/*! \brief namespace that contains staffs to be compatible with MPI */
namespace mpi {
/*!\brief enum of all operators */
enum OpType {
kMax, kMin, kSum, kBitwiseOR
};
/*!\brief enum of supported data types */
enum DataType {
kInt,
kUInt,
kDouble,
kFloat
};
} // namespace mpi
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI
* do not use this function directly
* \param sendrecvbuf buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param dtype the data type
* \param op the reduce operator type
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function *
*/
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*!
* \brief handle for customized reducer, used to handle customized reduce
* this class is mainly created for compatiblity issue with MPI's customized reduce
*/
class ReduceHandle {
public:
// constructor
ReduceHandle(void);
// destructor
~ReduceHandle(void);
/*!
* \brief initialize the reduce function,
* with the type the reduce function need to deal with
* the reduce function MUST be communicative
*/
void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
/*!
* \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
* \param type_n4bytes unit size of the type, in terms of 4bytes
* \param count number of elements to send
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
void Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);
private:
// handle data field
void *handle_;
// handle to the type field
void *htype_;
// the created type in 4 bytes
size_t created_type_nbytes_;
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_H

View File

@@ -9,7 +9,8 @@
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include "./engine.h"
#include <rabit/engine.h>
namespace rabit {
namespace engine {
/*! \brief EmptyEngine */

View File

@@ -9,8 +9,8 @@
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <cstdio>
#include "./engine.h"
#include "./utils.h"
#include <rabit/engine.h>
#include <rabit/utils.h>
#include <mpi.h>
namespace rabit {

132
src/io.h
View File

@@ -1,132 +0,0 @@
#ifndef RABIT_UTILS_IO_H
#define RABIT_UTILS_IO_H
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include "./utils.h"
#include "./serializable.h"
/*!
* \file io.h
* \brief utilities that implements different serializable interface
* \author Tianqi Chen
*/
namespace rabit {
namespace utils {
/*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream {
public:
/*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */
virtual size_t Tell(void) = 0;
};
/*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public ISeekStream {
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual ~MemoryFixSizeBuffer(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ + size <= buffer_size_,
"read can not have position excceed buffer length");
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
return curr_ptr_;
}
private:
/*! \brief in memory buffer */
char *p_buffer_;
/*! \brief current pointer */
size_t buffer_size_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryFixSizeBuffer
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public ISeekStream {
public:
MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
virtual ~MemoryBufferStream(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
return curr_ptr_;
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
/*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream {
public:
explicit FileStream(FILE *fp) : fp(fp) {}
explicit FileStream(void) {
this->fp = NULL;
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, size, 1, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
inline void Close(void) {
if (fp != NULL){
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
};
} // namespace utils
} // namespace rabit
#endif

View File

@@ -1,275 +0,0 @@
/*!
* \file rabit-inl.h
* \brief implementation of inline template function for rabit interface
*
* \author Tianqi Chen
*/
#ifndef RABIT_RABIT_INL_H
#define RABIT_RABIT_INL_H
// use engine for implementation
#include "./io.h"
#include "./utils.h"
namespace rabit {
namespace engine {
namespace mpi {
// template function to translate type to enum indicator
template<typename DType>
inline DataType GetType(void);
template<>
inline DataType GetType<int>(void) {
return kInt;
}
template<>
inline DataType GetType<unsigned>(void) {
return kUInt;
}
template<>
inline DataType GetType<float>(void) {
return kFloat;
}
template<>
inline DataType GetType<double>(void) {
return kDouble;
}
} // namespace mpi
} // namespace engine
namespace op {
struct Max {
const static engine::mpi::OpType kType = engine::mpi::kMax;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) {
if (dst < src) dst = src;
}
};
struct Min {
const static engine::mpi::OpType kType = engine::mpi::kMin;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) {
if (dst > src) dst = src;
}
};
struct Sum {
const static engine::mpi::OpType kType = engine::mpi::kSum;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) {
dst += src;
}
};
struct BitOR {
const static engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) {
dst |= src;
}
};
template<typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
const DType *src = (const DType*)src_;
DType *dst = (DType*)dst_;
for (int i = 0; i < len; ++i) {
OP::Reduce(dst[i], src[i]);
}
}
} // namespace op
// intialize the rabit engine
inline void Init(int argc, char *argv[]) {
engine::Init(argc, argv);
}
// finalize the rabit engine
inline void Finalize(void) {
engine::Finalize();
}
// get the rank of current process
inline int GetRank(void) {
return engine::GetEngine()->GetRank();
}
// the the size of the world
inline int GetWorldSize(void) {
return engine::GetEngine()->GetWorldSize();
}
// get the name of current processor
inline std::string GetProcessorName(void) {
return engine::GetEngine()->GetHost();
}
// broadcast data to all other nodes from root
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
engine::GetEngine()->Broadcast(sendrecv_data, size, root);
}
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
size_t size = sendrecv_data->size();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->size() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root);
}
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
// perform inplace Allreduce
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
}
// C++11 support for lambda prepare function
#if __cplusplus >= 201103L
inline void InvokeLambda_(void *fun) {
(*static_cast<std::function<void()>*>(fun))();
}
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
}
#endif // C++11
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);
}
#ifndef RABIT_STRICT_CXX98_
inline void TrackerPrintf(const char *fmt, ...) {
const int kPrintBuffer = 1 << 10;
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
TrackerPrint(msg);
}
#endif
// load latest check point
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
}
// checkpoint the model, meaning we finished a stage of execution
inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
engine::GetEngine()->CheckPoint(global_model, local_model);
}
// return the version number of currently stored model
inline int VersionNumber(void) {
return engine::GetEngine()->VersionNumber();
}
// ---------------------------------
// Code to handle customized Reduce
// ---------------------------------
// function to perform reduction for Reducer
template<typename DType>
inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
const size_t kUnit = sizeof(DType);
const char *psrc = reinterpret_cast<const char*>(src_);
char *pdst = reinterpret_cast<char*>(dst_);
DType tdst, tsrc;
for (size_t i = 0; i < len_; ++i) {
// use memcpy to avoid alignment issue
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
tdst.Reduce(tsrc);
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
}
}
template<typename DType>
inline Reducer<DType>::Reducer(void) {
this->handle_.Init(ReducerFunc_<DType>, sizeof(DType));
}
template<typename DType>
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
}
// function to perform reduction for SerializeReducer
template<typename DType>
inline void SerializeReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
int nbytes = engine::ReduceHandle::TypeSize(dtype);
// temp space
DType tsrc, tdst;
for (int i = 0; i < len_; ++i) {
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
tsrc.Load(fsrc);
tdst.Load(fdst);
// govern const check
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
fdst.Seek(0);
tdst.Save(fdst);
}
}
template<typename DType>
inline SerializeReducer<DType>::SerializeReducer(void) {
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
}
// closure to call Allreduce
template<typename DType>
struct SerializeReduceClosure {
DType *sendrecvobj;
size_t max_nbyte, count;
void (*prepare_fun)(void *arg);
void *prepare_arg;
std::string *p_buffer;
// invoke the closure
inline void Run(void) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Save(fs);
}
}
inline static void Invoke(void *c) {
static_cast<SerializeReduceClosure<DType>*>(c)->Run();
}
};
template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
buffer_.resize(max_nbyte * count);
// setup closure
SerializeReduceClosure<DType> c;
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
// invoke here
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
SerializeReduceClosure<DType>::Invoke, &c);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Load(fs);
}
}
#if __cplusplus >= 201103L
template<typename DType>
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun) {
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
}
template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbytes, size_t count,
std::function<void()> prepare_fun) {
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun);
}
#endif
} // namespace rabit
#endif

View File

@@ -1,286 +0,0 @@
#ifndef RABIT_RABIT_H
#define RABIT_RABIT_H
/*!
* \file rabit.h
* \brief This file defines unified Allreduce/Broadcast interface of rabit
* The actual implementation is redirected to rabit engine
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
*
* rabit.h and serializable.h is all the user need to use rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#include <string>
#include <vector>
// optionally support of lambda function in C++11, if available
#if __cplusplus >= 201103L
#include <functional>
#endif // C++11
// contains definition of ISerializable
#include "./serializable.h"
// engine definition of rabit, defines internal implementation
// to use rabit interface, there is no need to read engine.h rabit.h and serializable.h
// is suffice to use the interface
#include "./engine.h"
/*! \brief namespace of rabit */
namespace rabit {
/*! \brief namespace of operator */
namespace op {
/*! \brief maximum value */
struct Max;
/*! \brief minimum value */
struct Min;
/*! \brief perform sum */
struct Sum;
/*! \brief perform bitwise OR */
struct BitOR;
} // namespace op
/*!
* \brief intialize the rabit module, call this once function before using anything
* \param argc number of arguments in argv
* \param argv the array of input arguments
*/
inline void Init(int argc, char *argv[]);
/*!
* \brief finalize the rabit engine, call this function after you finished all jobs
*/
inline void Finalize(void);
/*! \brief get rank of current process */
inline int GetRank(void);
/*! \brief get total number of process */
inline int GetWorldSize(void);
/*! \brief whether rabit env is in distributed mode */
inline bool IsDistributed(void) {
return GetWorldSize() != 1;
}
/*! \brief get name of processor */
inline std::string GetProcessorName(void);
/*!
* \brief print the msg to the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg, the message to be printed
*/
inline void TrackerPrint(const std::string &msg);
#ifndef RABIT_STRICT_CXX98_
/*!
* \brief print the msg to the tracker, this function may not be available
* in very strict c++98 compilers, but is available most of the time
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param fmt the format string
*/
inline void TrackerPrintf(const char *fmt, ...);
#endif
/*!
* \brief broadcast an memory region to all others from root
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to send or recive buffer,
* \param size the size of the data
* \param root the root of process
*/
inline void Broadcast(void *sendrecv_data, size_t size, int root);
/*!
* \brief broadcast an std::vector<DType> to all others from root
* \param sendrecv_data the pointer to send or recive vector,
* for receiver, the vector does not need to be pre-allocated
* \param root the root of process
* \tparam DType the data type stored in vector, have to be simple data type
* that can be directly send by sending the sizeof(DType) data
*/
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
/*!
* \brief broadcast an std::string to all others from root
* \param sendrecv_data the pointer to send or recive vector,
* for receiver, the vector does not need to be pre-allocated
* \param root the root of process
*/
inline void Broadcast(std::string *sendrecv_data, int root);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size());
* ...
* \param sendrecvbuf buffer for both sending and recving data
* \param count number of elements to be reduced
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
* \tparam OP see namespace op, reduce operator
* \tparam DType type of data
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
// C++11 support for lambda prepare function
#if __cplusplus >= 201103L
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* with a prepare function specified by lambda function
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
* for (int i = 0; i < 10; ++i) {
* data[i] = i;
* }
* });
* ...
* \param sendrecvbuf buffer for both sending and recving data
* \param count number of elements to be reduced
* \param prepare_func Lazy lambda preprocessing function, prepare_fun() will be invoked
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \tparam OP see namespace op, reduce operator
* \tparam DType type of data
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
* \sa LoadCheckPoint, VersionNumber
*/
inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL);
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber(void);
// ----- extensions that allow customized reducer ------
// helper class to do customized reduce, user do not need to know the type
namespace engine {
class ReduceHandle;
} // namespace engine
/*!
* \brief template class to make customized reduce and all reduce easy
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
* \tparam DType data type that to be reduced
* DType must be a struct, with no pointer, and contains a function Reduce(const DType &d);
*/
template<typename DType>
class Reducer {
public:
Reducer(void);
/*!
* \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
* \param count number of elements to be reduced
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
#if __cplusplus >= 201103L
/*!
* \brief customized in-place all reduce operation, with lambda function as preprocessor
* \param sendrecvbuf pointer to the array of objects to be reduced
* \param count number of elements to be reduced
* \param prepare_fun lambda function executed to prepare the data, if necessary
*/
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif
private:
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
};
/*!
* \brief template class to make customized reduce,
* this class defines complex reducer handles all the data structure that can be
* serialized/deserialzed into fixed size buffer
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
*
* \tparam DType data type that to be reduced, DType must contain following functions:
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
*/
template<typename DType>
class SerializeReducer {
public:
SerializeReducer(void);
/*!
* \brief customized in-place all reduce operation
* \param sendrecvobj pointer to the array of objects to be reduced
* \param max_nbyte maximum amount of memory needed to serialize each object
* this includes budget limit for intermediate and final result
* \param count number of elements to be reduced
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
// C++11 support for lambda prepare function
#if __cplusplus >= 201103L
/*!
* \brief customized in-place all reduce operation, with lambda function as preprocessor
* \param sendrecvobj pointer to the array of objects to be reduced
* \param max_nbyte maximum amount of memory needed to serialize each object
* this includes budget limit for intermediate and final result
* \param count number of elements to be reduced
* \param prepare_fun lambda function executed to prepare the data, if necessary
*/
inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
std::function<void()> prepare_fun);
#endif
private:
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
/*! \brief temporal buffer used to do reduce*/
std::string buffer_;
};
} // namespace rabit
// implementation of template functions
#include "./rabit-inl.h"
#endif // RABIT_ALLREDUCE_H

View File

@@ -1,99 +0,0 @@
#ifndef RABIT_SERIALIZABLE_H
#define RABIT_SERIALIZABLE_H
#include <vector>
#include <string>
#include "./utils.h"
/*!
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
namespace rabit {
/*!
* \brief interface of stream I/O, used by ISerializable
* \sa ISerializable
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of se*/
class ISerializable {
public:
/*! \brief load the model from file */
virtual void Load(IStream &fi) = 0;
/*! \brief save the model to the stream*/
virtual void Save(IStream &fo) const = 0;
};
} // namespace rabit
#endif

View File

@@ -21,7 +21,7 @@
#endif
#include <string>
#include <cstring>
#include "./utils.h"
#include <rabit/utils.h>
#if defined(_WIN32)
typedef int ssize_t;

View File

@@ -1,23 +0,0 @@
/*!
* \file timer.h
* \brief This file defines the utils for timing
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_TIMER_H
#define RABIT_TIMER_H
#include <time.h>
#include "./utils.h"
namespace rabit {
namespace utils {
/*!
* \brief return time in seconds
*/
inline double GetTime(void) {
timespec ts;
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
}
}
}
#endif

View File

@@ -1,190 +0,0 @@
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
/*!
* \file utils.h
* \brief simple utils to support the code
* \author Tianqi Chen
*/
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <string>
#include <cstdlib>
#include <vector>
#ifndef RABIT_STRICT_CXX98_
#include <cstdarg>
#endif
#if !defined(__GNUC__)
#define fopen64 std::fopen
#endif
#ifdef _MSC_VER
// NOTE: sprintf_s is not equivalent to snprintf,
// they are equivalent when success, which is sufficient for our case
#define snprintf sprintf_s
#define vsnprintf vsprintf_s
#else
#ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32
#pragma message ("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif
#endif
#ifdef __APPLE__
#define off64_t off_t
#define fopen64 std::fopen
#endif
extern "C" {
#include <sys/types.h>
}
#endif
#ifdef _MSC_VER
typedef unsigned char uint8_t;
typedef unsigned short int uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long uint64_t;
typedef long int64_t;
#else
#include <inttypes.h>
#endif
namespace rabit {
/*! \brief namespace for helper utils of the project */
namespace utils {
/*! \brief error message buffer length */
const int kPrintBuffer = 1 << 12;
#ifndef RABIT_CUSTOMIZE_MSG_
/*!
* \brief handling of Assert error, caused by in-apropriate input
* \param msg error message
*/
inline void HandleAssertError(const char *msg) {
fprintf(stderr, "AssertError:%s\n", msg);
exit(-1);
}
/*!
* \brief handling of Check error, caused by in-apropriate input
* \param msg error message
*/
inline void HandleCheckError(const char *msg) {
fprintf(stderr, "%s\n", msg);
exit(-1);
}
inline void HandlePrint(const char *msg) {
printf("%s", msg);
}
inline void HandleLogPrint(const char *msg) {
fprintf(stderr, "%s", msg);
fflush(stderr);
}
#else
#ifndef RABIT_STRICT_CXX98_
// include declarations, some one must implement this
void HandleAssertError(const char *msg);
void HandleCheckError(const char *msg);
void HandlePrint(const char *msg);
#endif
#endif
#ifdef RABIT_STRICT_CXX98_
// these function pointers are to be assigned
extern "C" void (*Printf)(const char *fmt, ...);
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
extern "C" void (*Assert)(int exp, const char *fmt, ...);
extern "C" void (*Check)(int exp, const char *fmt, ...);
extern "C" void (*Error)(const char *fmt, ...);
#else
/*! \brief printf, print message to the console */
inline void Printf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandlePrint(msg.c_str());
}
/*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args;
va_start(args, fmt);
int ret = vsnprintf(buf, size, fmt, args);
va_end(args);
return ret;
}
/*! \brief assert an condition is true, use this to handle debug information */
inline void Assert(bool exp, const char *fmt, ...) {
if (!exp) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleAssertError(msg.c_str());
}
}
/*!\brief same as assert, but this is intended to be used as message for user*/
inline void Check(bool exp, const char *fmt, ...) {
if (!exp) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleCheckError(msg.c_str());
}
}
/*! \brief report error message, same as check */
inline void Error(const char *fmt, ...) {
{
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleCheckError(msg.c_str());
}
}
#endif
/*! \brief replace fopen, report error when the file open fails */
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
std::FILE *fp = fopen64(fname, flag);
Check(fp != NULL, "can not open file \"%s\"\n", fname);
return fp;
}
} // namespace utils
// easy utils that can be directly acessed in xgboost
/*! \brief get the beginning address of a vector */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) {
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*! \brief get the beginning address of a vector */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
inline char* BeginPtr(std::string &str) {
if (str.length() == 0) return NULL;
return &str[0];
}
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL;
return &str[0];
}
} // namespace rabit
#endif // RABIT_UTILS_H_