lint and travis

This commit is contained in:
tqchen 2015-07-03 15:15:11 -07:00
parent ceedf4ea96
commit 3cc49ad0e8
27 changed files with 423 additions and 296 deletions

1
.gitignore vendored
View File

@ -34,3 +34,4 @@
*tmp* *tmp*
*.rabit *.rabit
*.mock *.mock
dmlc-core

48
.travis.yml Normal file
View File

@ -0,0 +1,48 @@
# disable sudo to use container based build
sudo: false
# Use Build Matrix to do lint and build seperately
env:
matrix:
- TASK=lint LINT_LANG=cpp
- TASK=lint LINT_LANG=python
- TASK=doc
- TASK=build CXX=g++
# dependent apt packages
addons:
apt:
packages:
- doxygen
- wget
- git
- libcurl4-openssl-dev
- unzip
before_install:
- git clone https://github.com/dmlc/dmlc-core
- export TRAVIS=dmlc-core/scripts/travis/
- source ${TRAVIS}/travis_setup_env.sh
install:
- pip install cpplint pylint --user `whoami`
script: scripts/travis_script.sh
before_cache:
- ${TRAVIS}/travis_before_cache.sh
cache:
directories:
- ${HOME}/.cache/usr
notifications:
# Emails are sent to the committer's git-configured email address by default,
email:
on_success: change
on_failure: always

View File

@ -3,14 +3,18 @@ export CXX = g++
endif endif
export MPICXX = mpicxx export MPICXX = mpicxx
export LDFLAGS= -Llib -lrt export LDFLAGS= -Llib -lrt
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas
export CFLAGS = -O3 -msse2 $(WARNFLAGS) export CFLAGS = -O3 -msse2 $(WARNFLAGS)
ifndef WITH_FPIC ifndef WITH_FPIC
WITH_FPIC = 1 WITH_FPIC = 1
endif endif
ifeq ($(WITH_FPIC), 1) ifeq ($(WITH_FPIC), 1)
CFLAGS += -fPIC CFLAGS += -fPIC
endif
ifndef LINT_LANG
LINT_LANG="all"
endif endif
# build path # build path
@ -22,7 +26,9 @@ OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(B
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so wrapper/librabit_wrapper_mpi.so SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so wrapper/librabit_wrapper_mpi.so
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a
HEADERS=src/*.h include/*.h include/rabit/*.h HEADERS=src/*.h include/*.h include/rabit/*.h
.PHONY: clean all install mpi python DMLC=dmlc-core
.PHONY: clean all install mpi python lint doc
all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a
mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
@ -47,10 +53,10 @@ wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
wrapper/librabit_wrapper_mpi.so: $(BPATH)/rabit_wrapper.o lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so: $(BPATH)/rabit_wrapper.o lib/librabit_mpi.a
$(OBJ) : $(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(MPIOBJ) : $(MPIOBJ) :
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(ALIB): $(ALIB):
@ -59,6 +65,12 @@ $(ALIB):
$(SLIB) : $(SLIB) :
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
lint:
$(DMLC)/scripts/lint.py rabit $(LINT_LANG) src include wrapper
doc:
cd include; doxygen ../doc/Doxyfile; cd -
clean: clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~ $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~

View File

@ -101,8 +101,8 @@ FILE_PATTERNS =
RECURSIVE = NO RECURSIVE = NO
EXCLUDE = EXCLUDE =
EXCLUDE_SYMLINKS = NO EXCLUDE_SYMLINKS = NO
EXCLUDE_PATTERNS = *-inl.hpp EXCLUDE_PATTERNS = *-inl.hpp
EXCLUDE_SYMBOLS = EXCLUDE_SYMBOLS =
EXAMPLE_PATH = EXAMPLE_PATH =
EXAMPLE_PATTERNS = EXAMPLE_PATTERNS =
EXAMPLE_RECURSIVE = NO EXAMPLE_RECURSIVE = NO

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/bin/bash
cd ../include cd ../include
doxygen ../doc/Doxyfile doxygen ../doc/Doxyfile
cd ../doc cd ../doc

View File

@ -14,6 +14,7 @@
// include uint64_t only to make io standalone // include uint64_t only to make io standalone
#ifdef _MSC_VER #ifdef _MSC_VER
/*! \brief uint64 */
typedef unsigned __int64 uint64_t; typedef unsigned __int64 uint64_t;
#else #else
#include <inttypes.h> #include <inttypes.h>
@ -24,7 +25,7 @@ namespace dmlc {
/*! /*!
* \brief interface of stream I/O for serialization * \brief interface of stream I/O for serialization
*/ */
class Stream { class Stream { // NOLINT(*)
public: public:
/*! /*!
* \brief reads data from a stream * \brief reads data from a stream
@ -71,7 +72,7 @@ class Stream {
/*! /*!
* \brief writes a string * \brief writes a string
* \param str the string to be written/serialized * \param str the string to be written/serialized
*/ */
inline void Write(const std::string &str); inline void Write(const std::string &str);
/*! /*!
* \brief loads a string * \brief loads a string
@ -94,7 +95,7 @@ class SeekStream: public Stream {
* \brief generic factory function * \brief generic factory function
* create an SeekStream for read only, * create an SeekStream for read only,
* the stream will close the underlying files upon deletion * the stream will close the underlying files upon deletion
* error will be reported and the system will exit when create failed * error will be reported and the system will exit when create failed
* \param uri the uri of the input currently we support * \param uri the uri of the input currently we support
* hdfs://, s3://, and file:// by default file:// will be used * hdfs://, s3://, and file:// by default file:// will be used
* \param allow_null whether NULL can be returned, or directly report error * \param allow_null whether NULL can be returned, or directly report error
@ -107,12 +108,12 @@ class SeekStream: public Stream {
/*! \brief interface for serializable objects */ /*! \brief interface for serializable objects */
class Serializable { class Serializable {
public: public:
/*! /*!
* \brief load the model from a stream * \brief load the model from a stream
* \param fi stream where to load the model from * \param fi stream where to load the model from
*/ */
virtual void Load(Stream *fi) = 0; virtual void Load(Stream *fi) = 0;
/*! /*!
* \brief saves the model to a stream * \brief saves the model to a stream
* \param fo stream where to save the model to * \param fo stream where to save the model to
*/ */
@ -123,7 +124,7 @@ class Serializable {
* \brief input split creates that allows reading * \brief input split creates that allows reading
* of records from split of data, * of records from split of data,
* independent part that covers all the dataset * independent part that covers all the dataset
* *
* see InputSplit::Create for definition of record * see InputSplit::Create for definition of record
*/ */
class InputSplit { class InputSplit {
@ -141,7 +142,7 @@ class InputSplit {
* this is a hint so may not be enforced, * this is a hint so may not be enforced,
* but InputSplit will try adjust its internal buffer * but InputSplit will try adjust its internal buffer
* size to the hinted value * size to the hinted value
* \param chunk_size the chunk size * \param chunk_size the chunk size
*/ */
virtual void HintChunkSize(size_t chunk_size) {} virtual void HintChunkSize(size_t chunk_size) {}
/*! \brief reset the position of InputSplit to beginning */ /*! \brief reset the position of InputSplit to beginning */
@ -150,7 +151,7 @@ class InputSplit {
* \brief get the next record, the returning value * \brief get the next record, the returning value
* is valid until next call to NextRecord or NextChunk * is valid until next call to NextRecord or NextChunk
* caller can modify the memory content of out_rec * caller can modify the memory content of out_rec
* *
* For text, out_rec contains a single line * For text, out_rec contains a single line
* For recordio, out_rec contains one record content(with header striped) * For recordio, out_rec contains one record content(with header striped)
* *
@ -161,11 +162,11 @@ class InputSplit {
*/ */
virtual bool NextRecord(Blob *out_rec) = 0; virtual bool NextRecord(Blob *out_rec) = 0;
/*! /*!
* \brief get a chunk of memory that can contain multiple records, * \brief get a chunk of memory that can contain multiple records,
* the caller needs to parse the content of the resulting chunk, * the caller needs to parse the content of the resulting chunk,
* for text file, out_chunk can contain data of multiple lines * for text file, out_chunk can contain data of multiple lines
* for recordio, out_chunk can contain multiple records(including headers) * for recordio, out_chunk can contain multiple records(including headers)
* *
* This function ensures there won't be partial record in the chunk * This function ensures there won't be partial record in the chunk
* caller can modify the memory content of out_chunk, * caller can modify the memory content of out_chunk,
* the memory is valid until next call to NextRecord or NextChunk * the memory is valid until next call to NextRecord or NextChunk
@ -192,9 +193,10 @@ class InputSplit {
* List of possible types: "text", "recordio" * List of possible types: "text", "recordio"
* - "text": * - "text":
* text file, each line is treated as a record * text file, each line is treated as a record
* input split will split on \n or \r * input split will split on '\\n' or '\\r'
* - "recordio": * - "recordio":
* binary recordio file, see recordio.h * binary recordio file, see recordio.h
* \return a new input split
* \sa InputSplit::Type * \sa InputSplit::Type
*/ */
static InputSplit* Create(const char *uri, static InputSplit* Create(const char *uri,
@ -224,7 +226,7 @@ class ostream : public std::basic_ostream<char> {
* \param buffer_size internal streambuf size * \param buffer_size internal streambuf size
*/ */
explicit ostream(Stream *stream, explicit ostream(Stream *stream,
size_t buffer_size = 1 << 10) size_t buffer_size = (1 << 10))
: std::basic_ostream<char>(NULL), buf_(buffer_size) { : std::basic_ostream<char>(NULL), buf_(buffer_size) {
this->set_stream(stream); this->set_stream(stream);
} }
@ -240,7 +242,7 @@ class ostream : public std::basic_ostream<char> {
buf_.set_stream(stream); buf_.set_stream(stream);
this->rdbuf(&buf_); this->rdbuf(&buf_);
} }
private: private:
// internal streambuf // internal streambuf
class OutBuf : public std::streambuf { class OutBuf : public std::streambuf {
@ -251,7 +253,7 @@ class ostream : public std::basic_ostream<char> {
} }
// set stream to the buffer // set stream to the buffer
inline void set_stream(Stream *stream); inline void set_stream(Stream *stream);
private: private:
/*! \brief internal stream by StreamBuf */ /*! \brief internal stream by StreamBuf */
Stream *stream_; Stream *stream_;
@ -287,7 +289,7 @@ class istream : public std::basic_istream<char> {
* \param buffer_size internal buffer size * \param buffer_size internal buffer size
*/ */
explicit istream(Stream *stream, explicit istream(Stream *stream,
size_t buffer_size = 1 << 10) size_t buffer_size = (1 << 10))
: std::basic_istream<char>(NULL), buf_(buffer_size) { : std::basic_istream<char>(NULL), buf_(buffer_size) {
this->set_stream(stream); this->set_stream(stream);
} }
@ -325,7 +327,7 @@ class istream : public std::basic_istream<char> {
Stream *stream_; Stream *stream_;
/*! \brief how many bytes we read so far */ /*! \brief how many bytes we read so far */
size_t bytes_read_; size_t bytes_read_;
/*! \brief internal buffer */ /*! \brief internal buffer */
std::vector<char> buffer_; std::vector<char> buffer_;
// override underflow // override underflow
inline int_type underflow(); inline int_type underflow();
@ -402,7 +404,7 @@ inline int ostream::OutBuf::overflow(int c) {
// implementations for istream // implementations for istream
inline void istream::InBuf::set_stream(Stream *stream) { inline void istream::InBuf::set_stream(Stream *stream) {
stream_ = stream; stream_ = stream;
this->setg(&buffer_[0], &buffer_[0], &buffer_[0]); this->setg(&buffer_[0], &buffer_[0], &buffer_[0]);
} }
inline int istream::InBuf::underflow() { inline int istream::InBuf::underflow() {
char *bhead = &buffer_[0]; char *bhead = &buffer_[0];

View File

@ -8,12 +8,18 @@
* rabit.h and serializable.h is all what the user needs to use the rabit interface * rabit.h and serializable.h is all what the user needs to use the rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#ifndef RABIT_RABIT_H_ #ifndef RABIT_RABIT_H_ // NOLINT(*)
#define RABIT_RABIT_H_ #define RABIT_RABIT_H_ // NOLINT(*)
#include <string> #include <string>
#include <vector> #include <vector>
// whether or not use c++11 support
#ifndef DMLC_USE_CXX11
#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
// optionally support of lambda functions in C++11, if available // optionally support of lambda functions in C++11, if available
#if __cplusplus >= 201103L #if DMLC_USE_CXX11
#include <functional> #include <functional>
#endif // C++11 #endif // C++11
// contains definition of Serializable // contains definition of Serializable
@ -56,8 +62,8 @@ struct BitOR;
* \param argv the array of input arguments * \param argv the array of input arguments
*/ */
inline void Init(int argc, char *argv[]); inline void Init(int argc, char *argv[]);
/*! /*!
* \brief finalizes the rabit engine, call this function after you finished with all the jobs * \brief finalizes the rabit engine, call this function after you finished with all the jobs
*/ */
inline void Finalize(void); inline void Finalize(void);
/*! \brief gets rank of the current process */ /*! \brief gets rank of the current process */
@ -71,7 +77,7 @@ inline bool IsDistributed(void);
inline std::string GetProcessorName(void); inline std::string GetProcessorName(void);
/*! /*!
* \brief prints the msg to the tracker, * \brief prints the msg to the tracker,
* this function can be used to communicate progress information to * this function can be used to communicate progress information to
* the user who monitors the tracker * the user who monitors the tracker
* \param msg the message to be printed * \param msg the message to be printed
*/ */
@ -89,7 +95,7 @@ inline void TrackerPrintf(const char *fmt, ...);
/*! /*!
* \brief broadcasts a memory region to every node from the root * \brief broadcasts a memory region to every node from the root
* *
* Example: int a = 1; Broadcast(&a, sizeof(a), root); * Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to the send/receive buffer, * \param sendrecv_data the pointer to the send/receive buffer,
* \param size the data size * \param size the data size
* \param root the process root * \param root the process root
@ -113,7 +119,7 @@ inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
*/ */
inline void Broadcast(std::string *sendrecv_data, int root); inline void Broadcast(std::string *sendrecv_data, int root);
/*! /*!
* \brief performs in-place Allreduce on sendrecvbuf * \brief performs in-place Allreduce on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
* *
* Example Usage: the following code does an Allreduce and outputs the sum as the result * Example Usage: the following code does an Allreduce and outputs the sum as the result
@ -126,8 +132,8 @@ inline void Broadcast(std::string *sendrecv_data, int root);
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf. * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function * \param prepare_arg argument used to pass into the lazy preprocessing function
* \tparam OP see namespace op, reduce operator * \tparam OP see namespace op, reduce operator
* \tparam DType data type * \tparam DType data type
*/ */
template<typename OP, typename DType> template<typename OP, typename DType>
@ -135,7 +141,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL, void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL); void *prepare_arg = NULL);
// C++11 support for lambda prepare function // C++11 support for lambda prepare function
#if __cplusplus >= 201103L #if DMLC_USE_CXX11
/*! /*!
* \brief performs in-place Allreduce, on sendrecvbuf * \brief performs in-place Allreduce, on sendrecvbuf
* with a prepare function specified by a lambda function * with a prepare function specified by a lambda function
@ -154,7 +160,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked * \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf. * by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \tparam OP see namespace op, reduce operator * \tparam OP see namespace op, reduce operator
* \tparam DType data type * \tparam DType data type
*/ */
template<typename OP, typename DType> template<typename OP, typename DType>
@ -168,18 +174,18 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* is the same in every node * is the same in every node
* \param local_model pointer to the local model that is specific to the current node/rank * \param local_model pointer to the local model that is specific to the current node/rank
* this can be NULL when no local model is needed * this can be NULL when no local model is needed
* *
* \return the version number of the check point loaded * \return the version number of the check point loaded
* if returned version == 0, this means no model has been CheckPointed * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, users should do the necessary initialization by themselves * the p_model is not touched, users should do the necessary initialization by themselves
* *
* Common usage example: * Common usage example:
* int iter = rabit::LoadCheckPoint(&model); * int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters(); * if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) { * for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce * do many things, include allreduce
* rabit::CheckPoint(model); * rabit::CheckPoint(model);
* } * }
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
@ -188,7 +194,7 @@ inline int LoadCheckPoint(Serializable *global_model,
/*! /*!
* \brief checkpoints the model, meaning a stage of execution has finished. * \brief checkpoints the model, meaning a stage of execution has finished.
* every time we call check point, a version number will be increased by one * every time we call check point, a version number will be increased by one
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that the global_model * when calling this function, the caller needs to guarantee that the global_model
* is the same in every node * is the same in every node
@ -204,16 +210,16 @@ inline void CheckPoint(const Serializable *global_model,
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation). * when certain condition is met (see detailed explanation).
* *
* This is a "lazy" checkpoint such that only the pointer to the global_model is * This is a "lazy" checkpoint such that only the pointer to the global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that: * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes. * The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
* In other words, the global_model model can be changed only between the last call of * In other words, the global_model model can be changed only between the last call of
* Allreduce/Broadcast and LazyCheckPoint, both in the same version * Allreduce/Broadcast and LazyCheckPoint, both in the same version
* *
* For example, suppose the calling sequence is: * For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint/(or can be CheckPoint) * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint/(or can be CheckPoint)
* *
* Then the user MUST only change the global_model in code3. * Then the user MUST only change the global_model in code3.
* *
* The use of LazyCheckPoint instead of CheckPoint will improve the efficiency of the program. * The use of LazyCheckPoint instead of CheckPoint will improve the efficiency of the program.
@ -235,36 +241,36 @@ namespace engine {
class ReduceHandle; class ReduceHandle;
} // namespace engine } // namespace engine
/*! /*!
* \brief template class to make customized reduce and all reduce easy * \brief template class to make customized reduce and all reduce easy
* Do not use reducer directly in the function you call Finalize, * Do not use reducer directly in the function you call Finalize,
* because the destructor can execute after Finalize * because the destructor can execute after Finalize
* \tparam DType data type that to be reduced * \tparam DType data type that to be reduced
* \tparam freduce the customized reduction function * \tparam freduce the customized reduction function
* DType must be a struct, with no pointer * DType must be a struct, with no pointer
*/ */
template<typename DType, void (*freduce)(DType &dst, const DType &src)> template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
class Reducer { class Reducer {
public: public:
Reducer(void); Reducer(void);
/*! /*!
* \brief customized in-place all reduce operation * \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer * \param sendrecvbuf the in place send-recv buffer
* \param count number of elements to be reduced * \param count number of elements to be reduced
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf. * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function * \param prepare_arg argument used to pass into the lazy preprocessing function
*/ */
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL, void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL); void *prepare_arg = NULL);
#if __cplusplus >= 201103L #if DMLC_USE_CXX11
/*! /*!
* \brief customized in-place all reduce operation, with lambda function as preprocessor * \brief customized in-place all reduce operation, with lambda function as preprocessor
* \param sendrecvbuf pointer to the array of objects to be reduced * \param sendrecvbuf pointer to the array of objects to be reduced
* \param count number of elements to be reduced * \param count number of elements to be reduced
* \param prepare_fun lambda function executed to prepare the data, if necessary * \param prepare_fun lambda function executed to prepare the data, if necessary
*/ */
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun); std::function<void()> prepare_fun);
#endif #endif
@ -278,7 +284,7 @@ class Reducer {
* this class defines complex reducer handles all the data structure that can be * this class defines complex reducer handles all the data structure that can be
* serialized/deserialized into fixed size buffer * serialized/deserialized into fixed size buffer
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize * Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
* *
* \tparam DType data type that to be reduced, DType must contain the following functions: * \tparam DType data type that to be reduced, DType must contain the following functions:
* \tparam freduce the customized reduction function * \tparam freduce the customized reduction function
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte) * (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
@ -288,7 +294,7 @@ class SerializeReducer {
public: public:
SerializeReducer(void); SerializeReducer(void);
/*! /*!
* \brief customized in-place all reduce operation * \brief customized in-place all reduce operation
* \param sendrecvobj pointer to the array of objects to be reduced * \param sendrecvobj pointer to the array of objects to be reduced
* \param max_nbyte maximum amount of memory needed to serialize each object * \param max_nbyte maximum amount of memory needed to serialize each object
* this includes budget limit for intermediate and final result * this includes budget limit for intermediate and final result
@ -296,14 +302,14 @@ class SerializeReducer {
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf. * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called * If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function * \param prepare_arg argument used to pass into the lazy preprocessing function
*/ */
inline void Allreduce(DType *sendrecvobj, inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count, size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg) = NULL, void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL); void *prepare_arg = NULL);
// C++11 support for lambda prepare function // C++11 support for lambda prepare function
#if __cplusplus >= 201103L #if DMLC_USE_CXX11
/*! /*!
* \brief customized in-place all reduce operation, with lambda function as preprocessor * \brief customized in-place all reduce operation, with lambda function as preprocessor
* \param sendrecvobj pointer to the array of objects to be reduced * \param sendrecvobj pointer to the array of objects to be reduced
@ -311,7 +317,7 @@ class SerializeReducer {
* this includes budget limit for intermediate and final result * this includes budget limit for intermediate and final result
* \param count number of elements to be reduced * \param count number of elements to be reduced
* \param prepare_fun lambda function executed to prepare the data, if necessary * \param prepare_fun lambda function executed to prepare the data, if necessary
*/ */
inline void Allreduce(DType *sendrecvobj, inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count, size_t max_nbyte, size_t count,
std::function<void()> prepare_fun); std::function<void()> prepare_fun);
@ -326,4 +332,4 @@ class SerializeReducer {
} // namespace rabit } // namespace rabit
// implementation of template functions // implementation of template functions
#include "./rabit/rabit-inl.h" #include "./rabit/rabit-inl.h"
#endif // RABIT_RABIT_H_ #endif // RABIT_RABIT_H_ // NOLINT(*)

View File

@ -4,8 +4,8 @@
* \brief utilities with different serializable implementations * \brief utilities with different serializable implementations
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_UTILS_IO_H_ #ifndef RABIT_IO_H_
#define RABIT_UTILS_IO_H_ #define RABIT_IO_H_
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
#include <cstring> #include <cstring>
@ -51,6 +51,7 @@ struct MemoryFixSizeBuffer : public SeekStream {
virtual bool AtEnd(void) const { virtual bool AtEnd(void) const {
return curr_ptr_ == buffer_size_; return curr_ptr_ == buffer_size_;
} }
private: private:
/*! \brief in memory buffer */ /*! \brief in memory buffer */
char *p_buffer_; char *p_buffer_;
@ -93,6 +94,7 @@ struct MemoryBufferStream : public SeekStream {
virtual bool AtEnd(void) const { virtual bool AtEnd(void) const {
return curr_ptr_ == p_buffer_->length(); return curr_ptr_ == p_buffer_->length();
} }
private: private:
/*! \brief in memory buffer */ /*! \brief in memory buffer */
std::string *p_buffer_; std::string *p_buffer_;
@ -101,4 +103,4 @@ struct MemoryBufferStream : public SeekStream {
}; // class MemoryBufferStream }; // class MemoryBufferStream
} // namespace utils } // namespace utils
} // namespace rabit } // namespace rabit
#endif // RABIT_UTILS_IO_H_ #endif // RABIT_IO_H_

View File

@ -1,12 +1,15 @@
/*! /*!
* Copyright by Contributors
* \file rabit-inl.h * \file rabit-inl.h
* \brief implementation of inline template function for rabit interface * \brief implementation of inline template function for rabit interface
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_RABIT_INL_H #ifndef RABIT_RABIT_INL_H_
#define RABIT_RABIT_INL_H #define RABIT_RABIT_INL_H_
// use engine for implementation // use engine for implementation
#include <vector>
#include <string>
#include "./io.h" #include "./io.h"
#include "./utils.h" #include "./utils.h"
#include "../rabit.h" #include "../rabit.h"
@ -30,15 +33,15 @@ inline DataType GetType<int>(void) {
return kInt; return kInt;
} }
template<> template<>
inline DataType GetType<unsigned int>(void) { inline DataType GetType<unsigned int>(void) { // NOLINT(*)
return kUInt; return kUInt;
} }
template<> template<>
inline DataType GetType<long>(void) { inline DataType GetType<long>(void) { // NOLINT(*)
return kLong; return kLong;
} }
template<> template<>
inline DataType GetType<unsigned long>(void) { inline DataType GetType<unsigned long>(void) { // NOLINT(*)
return kULong; return kULong;
} }
template<> template<>
@ -50,54 +53,54 @@ inline DataType GetType<double>(void) {
return kDouble; return kDouble;
} }
template<> template<>
inline DataType GetType<long long>(void) { inline DataType GetType<long long>(void) { // NOLINT(*)
return kLongLong; return kLongLong;
} }
template<> template<>
inline DataType GetType<unsigned long long>(void) { inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
return kULongLong; return kULongLong;
} }
} // namespace mpi } // namespace mpi
} // namespace engine } // namespace engine
namespace op { namespace op {
struct Max { struct Max {
const static engine::mpi::OpType kType = engine::mpi::kMax; static const engine::mpi::OpType kType = engine::mpi::kMax;
template<typename DType> template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
if (dst < src) dst = src; if (dst < src) dst = src;
} }
}; };
struct Min { struct Min {
const static engine::mpi::OpType kType = engine::mpi::kMin; static const engine::mpi::OpType kType = engine::mpi::kMin;
template<typename DType> template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
if (dst > src) dst = src; if (dst > src) dst = src;
} }
}; };
struct Sum { struct Sum {
const static engine::mpi::OpType kType = engine::mpi::kSum; static const engine::mpi::OpType kType = engine::mpi::kSum;
template<typename DType> template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst += src; dst += src;
} }
}; };
struct BitOR { struct BitOR {
const static engine::mpi::OpType kType = engine::mpi::kBitwiseOR; static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
template<typename DType> template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst |= src; dst |= src;
} }
}; };
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
const DType *src = (const DType*)src_; const DType *src = (const DType*)src_;
DType *dst = (DType*)dst_; DType *dst = (DType*)dst_; // NOLINT(*)
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
OP::Reduce(dst[i], src[i]); OP::Reduce(dst[i], src[i]);
} }
} }
} // namespace op } // namespace op
// intialize the rabit engine // intialize the rabit engine
inline void Init(int argc, char *argv[]) { inline void Init(int argc, char *argv[]) {
@ -152,23 +155,23 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
// perform inplace Allreduce // perform inplace Allreduce
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>, engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg); engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
} }
// C++11 support for lambda prepare function // C++11 support for lambda prepare function
#if __cplusplus >= 201103L #if DMLC_USE_CXX11
inline void InvokeLambda_(void *fun) { inline void InvokeLambda_(void *fun) {
(*static_cast<std::function<void()>*>(fun))(); (*static_cast<std::function<void()>*>(fun))();
} }
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) { inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>, engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun); engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
} }
#endif // C++11 #endif // C++11
// print message to the tracker // print message to the tracker
inline void TrackerPrint(const std::string &msg) { inline void TrackerPrint(const std::string &msg) {
@ -223,15 +226,16 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
} }
} }
// function to perform reduction for Reducer // function to perform reduction for Reducer
template<typename DType, void (*freduce)(DType &dst, const DType &src)> template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { inline void ReducerAlign_(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) {
const DType *psrc = reinterpret_cast<const DType*>(src_); const DType *psrc = reinterpret_cast<const DType*>(src_);
DType *pdst = reinterpret_cast<DType*>(dst_); DType *pdst = reinterpret_cast<DType*>(dst_);
for (int i = 0; i < len_; ++i) { for (int i = 0; i < len_; ++i) {
freduce(pdst[i], psrc[i]); freduce(pdst[i], psrc[i]);
} }
} }
template<typename DType, void (*freduce)(DType &dst, const DType &src)> template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline Reducer<DType, freduce>::Reducer(void) { inline Reducer<DType, freduce>::Reducer(void) {
// it is safe to directly use handle for aligned data types // it is safe to directly use handle for aligned data types
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) { if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
@ -240,7 +244,7 @@ inline Reducer<DType, freduce>::Reducer(void) {
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType)); this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
} }
} }
template<typename DType, void (*freduce)(DType &dst, const DType &src)> template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count, inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
@ -248,13 +252,14 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
} }
// function to perform reduction for SerializeReducer // function to perform reduction for SerializeReducer
template<typename DType> template<typename DType>
inline void SerializeReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { inline void SerializeReducerFunc_(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) {
int nbytes = engine::ReduceHandle::TypeSize(dtype); int nbytes = engine::ReduceHandle::TypeSize(dtype);
// temp space // temp space
DType tsrc, tdst; DType tsrc, tdst;
for (int i = 0; i < len_; ++i) { for (int i = 0; i < len_; ++i) {
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); // NOLINT(*)
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); // NOLINT(*)
tsrc.Load(fsrc); tsrc.Load(fsrc);
tdst.Load(fdst); tdst.Load(fdst);
// govern const check // govern const check
@ -296,8 +301,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
// setup closure // setup closure
SerializeReduceClosure<DType> c; SerializeReduceClosure<DType> c;
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count; c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_; c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
// invoke here // invoke here
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
SerializeReduceClosure<DType>::Invoke, &c); SerializeReduceClosure<DType>::Invoke, &c);
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
@ -306,8 +311,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
} }
} }
#if __cplusplus >= 201103L #if DMLC_USE_CXX11
template<typename DType, void (*freduce)(DType &dst, const DType &src)> template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count, inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun) { std::function<void()> prepare_fun) {
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun); this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
@ -320,4 +325,4 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
} }
#endif #endif
} // namespace rabit } // namespace rabit
#endif #endif // RABIT_RABIT_INL_H_

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright by Contributors
* \file timer.h * \file timer.h
* \brief This file defines the utils for timing * \brief This file defines the utils for timing
* \author Tianqi Chen, Nacho, Tianyi * \author Tianqi Chen, Nacho, Tianyi
@ -18,7 +19,6 @@ namespace utils {
* \brief return time in seconds, not cross platform, avoid to use this in most places * \brief return time in seconds, not cross platform, avoid to use this in most places
*/ */
inline double GetTime(void) { inline double GetTime(void) {
// TODO: use c++11 chrono when c++11 was available
#ifdef __MACH__ #ifdef __MACH__
clock_serv_t cclock; clock_serv_t cclock;
mach_timespec_t mts; mach_timespec_t mts;
@ -32,7 +32,6 @@ inline double GetTime(void) {
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time"); utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9; return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
#else #else
// TODO: add MSVC macro, and MSVC timer
return static_cast<double>(time(NULL)); return static_cast<double>(time(NULL));
#endif #endif
#endif #endif

View File

@ -27,7 +27,7 @@
#else #else
#ifdef _FILE_OFFSET_BITS #ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32 #if _FILE_OFFSET_BITS == 32
#pragma message ("Warning: FILE OFFSET BITS defined to be 32 bit") #pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif #endif
#endif #endif
@ -59,17 +59,17 @@ namespace utils {
const int kPrintBuffer = 1 << 12; const int kPrintBuffer = 1 << 12;
#ifndef RABIT_CUSTOMIZE_MSG_ #ifndef RABIT_CUSTOMIZE_MSG_
/*! /*!
* \brief handling of Assert error, caused by inappropriate input * \brief handling of Assert error, caused by inappropriate input
* \param msg error message * \param msg error message
*/ */
inline void HandleAssertError(const char *msg) { inline void HandleAssertError(const char *msg) {
fprintf(stderr, "AssertError:%s\n", msg); fprintf(stderr, "AssertError:%s\n", msg);
exit(-1); exit(-1);
} }
/*! /*!
* \brief handling of Check error, caused by inappropriate input * \brief handling of Check error, caused by inappropriate input
* \param msg error message * \param msg error message
*/ */
inline void HandleCheckError(const char *msg) { inline void HandleCheckError(const char *msg) {
fprintf(stderr, "%s\n", msg); fprintf(stderr, "%s\n", msg);
@ -163,7 +163,7 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
// easy utils that can be directly accessed in xgboost // easy utils that can be directly accessed in xgboost
/*! \brief get the beginning address of a vector */ /*! \brief get the beginning address of a vector */
template<typename T> template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) { if (vec.size() == 0) {
return NULL; return NULL;
} else { } else {
@ -172,14 +172,14 @@ inline T *BeginPtr(std::vector<T> &vec) {
} }
/*! \brief get the beginning address of a vector */ /*! \brief get the beginning address of a vector */
template<typename T> template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) { inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) { if (vec.size() == 0) {
return NULL; return NULL;
} else { } else {
return &vec[0]; return &vec[0];
} }
} }
inline char* BeginPtr(std::string &str) { inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return NULL; if (str.length() == 0) return NULL;
return &str[0]; return &str[0];
} }

View File

@ -4,8 +4,8 @@
* \brief defines serializable interface of rabit * \brief defines serializable interface of rabit
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_RABIT_SERIALIZABLE_H_ #ifndef RABIT_SERIALIZABLE_H_
#define RABIT_RABIT_SERIALIZABLE_H_ #define RABIT_SERIALIZABLE_H_
#include <vector> #include <vector>
#include <string> #include <string>
#include "./rabit/utils.h" #include "./rabit/utils.h"
@ -13,15 +13,15 @@
namespace rabit { namespace rabit {
/*! /*!
* \brief defines stream used in rabit * \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h * see definition of Stream in dmlc/io.h
*/ */
typedef dmlc::Stream Stream; typedef dmlc::Stream Stream;
/*! /*!
* \brief defines serializable objects used in rabit * \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h * see definition of Serializable in dmlc/io.h
*/ */
typedef dmlc::Serializable Serializable; typedef dmlc::Serializable Serializable;
} // namespace rabit } // namespace rabit
#endif // RABIT_RABIT_SERIALIZABLE_H_ #endif // RABIT_SERIALIZABLE_H_

8
scripts/travis_runtest.sh Executable file
View File

@ -0,0 +1,8 @@
#!/bin/bash
make -f test.mk model_recover_10_10k || exit -1
make -f test.mk model_recover_10_10k_die_same || exit -1
make -f test.mk local_recover_10_10k || exit -1
make -f test.mk pylocal_recover_10_10k || exit -1
make -f test.mk lazy_recover_10_10k_die_hard || exit -1
make -f test.mk lazy_recover_10_10k_die_same || exit -1
make -f test.mk ringallreduce_10_10k || exit -1

22
scripts/travis_script.sh Executable file
View File

@ -0,0 +1,22 @@
#!/bin/bash
# main script of travis
if [ ${TASK} == "lint" ]; then
make lint || exit -1
fi
if [ ${TASK} == "doc" ]; then
make doc 2>log.txt
(cat log.txt|grep warning) && exit -1
fi
if [ ${TASK} == "build" ]; then
make all || exit -1
fi
if [ ${TASK} == "test" ]; then
cd test
make all || exit -1
./travis_runtest.sh || exit -1
fi

View File

@ -94,7 +94,8 @@ void AllreduceBase::Init(void) {
} }
} }
if (dmlc_role != "worker") { if (dmlc_role != "worker") {
fprintf(stderr, "Rabit Module currently only work with dmlc worker, quit this program by exit 0\n"); fprintf(stderr, "Rabit Module currently only work with dmlc worker"\
", quit this program by exit 0\n");
exit(0); exit(0);
} }
// clear the setting before start reconnection // clear the setting before start reconnection
@ -134,7 +135,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
// util to parse data with unit suffix // util to parse data with unit suffix
inline size_t ParseUnit(const char *name, const char *val) { inline size_t ParseUnit(const char *name, const char *val) {
char unit; char unit;
unsigned long amt; unsigned long amt; // NOLINT(*)
int n = sscanf(val, "%lu%c", &amt, &unit); int n = sscanf(val, "%lu%c", &amt, &unit);
size_t amount = amt; size_t amount = amt;
if (n == 2) { if (n == 2) {
@ -154,7 +155,7 @@ inline size_t ParseUnit(const char *name, const char *val) {
} }
} }
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
* \param val parameter value * \param val parameter value
*/ */
@ -258,7 +259,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
} else { } else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
} }
} }
int ngood = static_cast<int>(good_link.size()); int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5"); "ReConnectLink failure 5");
@ -359,7 +360,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce. * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce. * It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
* *
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have * \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced * \param count number of elements to be reduced
@ -440,7 +441,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
selecter.WatchRead(links[i].sock); selecter.WatchRead(links[i].sock);
} }
// size_write <= size_read // size_write <= size_read
if (links[i].size_write != total_size){ if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) { if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock); selecter.WatchWrite(links[i].sock);
} }
@ -477,7 +478,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t max_reduce = total_size; size_t max_reduce = total_size;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index) { if (i != parent_index) {
max_reduce= std::min(max_reduce, links[i].size_read); max_reduce = std::min(max_reduce, links[i].size_read);
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size, utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
"buffer size inconsistent"); "buffer size inconsistent");
buffer_size = links[i].buffer_size; buffer_size = links[i].buffer_size;
@ -525,7 +526,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
ssize_t len = links[parent_index].sock. ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in); Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) { if (len == 0) {
links[parent_index].sock.Close(); links[parent_index].sock.Close();
return ReportError(&links[parent_index], kRecvZeroLen); return ReportError(&links[parent_index], kRecvZeroLen);
} }
if (len != -1) { if (len != -1) {
@ -670,7 +671,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin, size_t slice_begin,
size_t slice_end, size_t slice_end,
size_t size_prev_slice) { size_t size_prev_slice) {
// read from next link and send to prev one // read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next; LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure // need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size && utils::Assert(next.rank == (rank + 1) % world_size &&
@ -678,11 +679,11 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
"need to assume rank structure"); "need to assume rank structure");
// send recv buffer // send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_); char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
const size_t stop_read = total_size + slice_begin; const size_t stop_read = total_size + slice_begin;
const size_t stop_write = total_size + slice_begin - size_prev_slice; const size_t stop_write = total_size + slice_begin - size_prev_slice;
size_t write_ptr = slice_begin; size_t write_ptr = slice_begin;
size_t read_ptr = slice_end; size_t read_ptr = slice_end;
while (true) { while (true) {
// select helper // select helper
bool finished = true; bool finished = true;
@ -733,7 +734,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, * \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
* and will return the cause of failure * and will return the cause of failure
* *
* Ring-based algorithm * Ring-based algorithm
* *
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
@ -748,7 +749,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
ReduceFunction reducer) { ReduceFunction reducer) {
// read from next link and send to prev one // read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next; LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure // need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size && utils::Assert(next.rank == (rank + 1) % world_size &&
@ -757,7 +758,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
// total size of message // total size of message
const size_t total_size = type_nbytes * count; const size_t total_size = type_nbytes * count;
size_t n = static_cast<size_t>(world_size); size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n; size_t step = (count + n - 1) / n;
size_t r = static_cast<size_t>(next.rank); size_t r = static_cast<size_t>(next.rank);
size_t write_ptr = std::min(r * step, count) * type_nbytes; size_t write_ptr = std::min(r * step, count) * type_nbytes;
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes; size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
@ -830,7 +831,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
if (ret != kSuccess) return ReportError(&prev, ret); if (ret != kSuccess) return ReportError(&prev, ret);
} }
} }
} }
return kSuccess; return kSuccess;
} }
/*! /*!
@ -857,7 +858,7 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
size_t end = std::min((rank + 1) * step, count) * type_nbytes; size_t end = std::min((rank + 1) * step, count) * type_nbytes;
// previous rank // previous rank
int prank = ring_prev->rank; int prank = ring_prev->rank;
// get rank of previous // get rank of previous
return TryAllgatherRing return TryAllgatherRing
(sendrecvbuf_, type_nbytes * count, (sendrecvbuf_, type_nbytes * count,
begin, end, begin, end,

View File

@ -42,7 +42,7 @@ class AllreduceBase : public IEngine {
// shutdown the engine // shutdown the engine
virtual void Shutdown(void); virtual void Shutdown(void);
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
* \param val parameter value * \param val parameter value
*/ */
@ -72,7 +72,7 @@ class AllreduceBase : public IEngine {
return host_uri; return host_uri;
} }
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have * \param type_nbytes the unit number of bytes the type have
@ -82,7 +82,7 @@ class AllreduceBase : public IEngine {
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function * \param prepare_arg argument used to passed into the lazy preprocessing function
*/ */
virtual void Allreduce(void *sendrecvbuf_, virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
@ -117,14 +117,14 @@ class AllreduceBase : public IEngine {
* \return the version number of check point loaded * \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves * the p_model is not touched, user should do necessary initialization by themselves
* *
* Common usage example: * Common usage example:
* int iter = rabit::LoadCheckPoint(&model); * int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters(); * if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) { * for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce * do many things, include allreduce
* rabit::CheckPoint(model); * rabit::CheckPoint(model);
* } * }
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
@ -135,7 +135,7 @@ class AllreduceBase : public IEngine {
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model * when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes * is the same in all nodes
@ -155,16 +155,16 @@ class AllreduceBase : public IEngine {
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination). * when certain condition is met(see detailed expplaination).
* *
* This is a "lazy" checkpoint such that only the pointer to global_model is * This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that: * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs. * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of * In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version * Allreduce/Broadcast and LazyCheckPoint in current version
* *
* For example, suppose the calling sequence is: * For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
* *
* If user can only changes global_model in code3, then LazyCheckPoint can be used to * If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program. * improve efficiency of the program.
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
@ -191,8 +191,8 @@ class AllreduceBase : public IEngine {
virtual void InitAfterException(void) { virtual void InitAfterException(void) {
utils::Error("InitAfterException: not implemented"); utils::Error("InitAfterException: not implemented");
} }
/*! /*!
* \brief report current status to the job tracker * \brief report current status to the job tracker
* depending on the job tracker we are in * depending on the job tracker we are in
*/ */
inline void ReportStatus(void) const { inline void ReportStatus(void) const {
@ -213,7 +213,7 @@ class AllreduceBase : public IEngine {
kRecvZeroLen, kRecvZeroLen,
/*! \brief a neighbor node go down, the connection is dropped */ /*! \brief a neighbor node go down, the connection is dropped */
kSockError, kSockError,
/*! /*!
* \brief another node which is not my neighbor go down, * \brief another node which is not my neighbor go down,
* get Out-of-Band exception notification from my neighbor * get Out-of-Band exception notification from my neighbor
*/ */
@ -225,7 +225,7 @@ class AllreduceBase : public IEngine {
ReturnTypeEnum value; ReturnTypeEnum value;
// constructor // constructor
ReturnType() {} ReturnType() {}
ReturnType(ReturnTypeEnum value) : value(value){} ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
inline bool operator==(const ReturnTypeEnum &v) const { inline bool operator==(const ReturnTypeEnum &v) const {
return value == v; return value == v;
} }
@ -235,11 +235,11 @@ class AllreduceBase : public IEngine {
}; };
/*! \brief translate errno to return type */ /*! \brief translate errno to return type */
inline static ReturnType Errno2Return() { inline static ReturnType Errno2Return() {
int errsv = utils::Socket::GetLastError(); int errsv = utils::Socket::GetLastError();
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess; if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
#ifdef _WIN32 #ifdef _WIN32
if (errsv == WSAEWOULDBLOCK) return kSuccess; if (errsv == WSAEWOULDBLOCK) return kSuccess;
if (errsv == WSAECONNRESET) return kConnReset; if (errsv == WSAECONNRESET) return kConnReset;
#endif #endif
if (errsv == ECONNRESET) return kConnReset; if (errsv == ECONNRESET) return kConnReset;
return kSockError; return kSockError;
@ -260,7 +260,7 @@ class AllreduceBase : public IEngine {
// buffer size, in bytes // buffer size, in bytes
size_t buffer_size; size_t buffer_size;
// constructor // constructor
LinkRecord(void) LinkRecord(void)
: buffer_head(NULL), buffer_size(0) { : buffer_head(NULL), buffer_size(0) {
} }
// initialize buffer // initialize buffer
@ -377,7 +377,7 @@ class AllreduceBase : public IEngine {
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce. * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce. * It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
* *
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have * \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced * \param count number of elements to be reduced
@ -397,7 +397,7 @@ class AllreduceBase : public IEngine {
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType * \sa ReturnType
*/ */
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf, * \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction * this function implements tree-shape reduction
@ -433,14 +433,14 @@ class AllreduceBase : public IEngine {
size_t size_prev_slice); size_t size_prev_slice);
/*! /*!
* \brief perform in-place allreduce, reduce on the sendrecvbuf, * \brief perform in-place allreduce, reduce on the sendrecvbuf,
* *
* after the function, node k get k-th segment of the reduction result * after the function, node k get k-th segment of the reduction result
* the k-th segment is defined by [k * step, min((k + 1) * step,count) ) * the k-th segment is defined by [k * step, min((k + 1) * step,count) )
* where step = ceil(count / world_size) * where step = ceil(count / world_size)
* *
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have * \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced * \param count number of elements to be reduced
* \param reducer reduce function * \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce * \sa ReturnType, TryAllreduce
@ -465,7 +465,7 @@ class AllreduceBase : public IEngine {
size_t count, size_t count,
ReduceFunction reducer); ReduceFunction reducer);
/*! /*!
* \brief function used to report error when a link goes wrong * \brief function used to report error when a link goes wrong
* \param link the pointer to the link who causes the error * \param link the pointer to the link who causes the error
* \param err the error type * \param err the error type
*/ */
@ -522,4 +522,4 @@ class AllreduceBase : public IEngine {
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
#endif // RABIT_ALLREDUCE_BASE_H #endif // RABIT_ALLREDUCE_BASE_H_

View File

@ -1,8 +1,9 @@
/*! /*!
* Copyright by Contributors
* \file allreduce_mock.h * \file allreduce_mock.h
* \brief Mock test module of AllReduce engine, * \brief Mock test module of AllReduce engine,
* insert failures in certain call point, to test if the engine is robust to failure * insert failures in certain call point, to test if the engine is robust to failure
* *
* \author Ignacio Cano, Tianqi Chen * \author Ignacio Cano, Tianqi Chen
*/ */
#ifndef RABIT_ALLREDUCE_MOCK_H_ #ifndef RABIT_ALLREDUCE_MOCK_H_
@ -68,7 +69,7 @@ class AllreduceMock : public AllreduceRobust {
DummySerializer dum; DummySerializer dum;
ComboSerializer com(global_model, local_model); ComboSerializer com(global_model, local_model);
return AllreduceRobust::LoadCheckPoint(&dum, &com); return AllreduceRobust::LoadCheckPoint(&dum, &com);
} }
} }
virtual void CheckPoint(const Serializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model) { const Serializable *local_model) {
@ -100,6 +101,7 @@ class AllreduceMock : public AllreduceRobust {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model); AllreduceRobust::LazyCheckPoint(global_model);
} }
protected: protected:
// force checkpoint to local // force checkpoint to local
int force_local; int force_local;
@ -108,7 +110,7 @@ class AllreduceMock : public AllreduceRobust {
// sum of allreduce // sum of allreduce
double tsum_allreduce; double tsum_allreduce;
double time_checkpoint; double time_checkpoint;
private: private:
struct DummySerializer : public Serializable { struct DummySerializer : public Serializable {
virtual void Load(Stream *fi) { virtual void Load(Stream *fi) {
@ -126,7 +128,7 @@ class AllreduceMock : public AllreduceRobust {
} }
ComboSerializer(const Serializable *lhs, const Serializable *rhs) ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) { : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
} }
virtual void Load(Stream *fi) { virtual void Load(Stream *fi) {
if (lhs != NULL) lhs->Load(fi); if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi); if (rhs != NULL) rhs->Load(fi);
@ -143,10 +145,10 @@ class AllreduceMock : public AllreduceRobust {
int seqno; int seqno;
int ntrial; int ntrial;
MockKey(void) {} MockKey(void) {}
MockKey(int rank, int version, int seqno, int ntrial) MockKey(int rank, int version, int seqno, int ntrial)
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {} : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
inline bool operator==(const MockKey &b) const { inline bool operator==(const MockKey &b) const {
return rank == b.rank && return rank == b.rank &&
version == b.version && version == b.version &&
seqno == b.seqno && seqno == b.seqno &&
ntrial == b.ntrial; ntrial == b.ntrial;
@ -173,4 +175,4 @@ class AllreduceMock : public AllreduceRobust {
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H_ #endif // RABIT_ALLREDUCE_MOCK_H_

View File

@ -2,17 +2,17 @@
* Copyright (c) 2014 by Contributors * Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h * \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust * \brief implementation of inline template function in AllreduceRobust
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_ENGINE_ROBUST_INL_H_ #ifndef RABIT_ALLREDUCE_ROBUST_INL_H_
#define RABIT_ENGINE_ROBUST_INL_H_ #define RABIT_ALLREDUCE_ROBUST_INL_H_
#include <vector> #include <vector>
namespace rabit { namespace rabit {
namespace engine { namespace engine {
/*! /*!
* \brief run message passing algorithm on the allreduce tree * \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out * the result is edge message stored in p_edge_in and p_edge_out
* \param node_value the value associated with current node * \param node_value the value associated with current node
* \param p_edge_in used to store input message from each of the edge * \param p_edge_in used to store input message from each of the edge
@ -35,7 +35,7 @@ inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value, AllreduceRobust::MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in, std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out, std::vector<EdgeType> *p_edge_out,
EdgeType (*func) EdgeType(*func)
(const NodeType &node_value, (const NodeType &node_value,
const std::vector<EdgeType> &edge_in, const std::vector<EdgeType> &edge_in,
size_t out_index)) { size_t out_index)) {
@ -80,8 +80,16 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
selecter.WatchRead(links[i].sock); selecter.WatchRead(links[i].sock);
} }
break; break;
case 1: if (i == parent_index) selecter.WatchWrite(links[i].sock); break; case 1:
case 2: if (i == parent_index) selecter.WatchRead(links[i].sock); break; if (i == parent_index) {
selecter.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
selecter.WatchRead(links[i].sock);
}
break;
case 3: case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock); selecter.WatchWrite(links[i].sock);
@ -158,4 +166,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
} }
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
#endif // RABIT_ENGINE_ROBUST_INL_H_ #endif // RABIT_ALLREDUCE_ROBUST_INL_H_

View File

@ -27,7 +27,7 @@ AllreduceRobust::AllreduceRobust(void) {
result_buffer_round = 1; result_buffer_round = 1;
global_lazycheck = NULL; global_lazycheck = NULL;
use_local_model = -1; use_local_model = -1;
recover_counter = 0; recover_counter = 0;
env_vars.push_back("rabit_global_replica"); env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica"); env_vars.push_back("rabit_local_replica");
} }
@ -49,7 +49,7 @@ void AllreduceRobust::Shutdown(void) {
AllreduceBase::Shutdown(); AllreduceBase::Shutdown();
} }
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
* \param val parameter value * \param val parameter value
*/ */
@ -61,7 +61,7 @@ void AllreduceRobust::SetParam(const char *name, const char *val) {
} }
} }
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have * \param type_nbytes the unit number of bytes the type have
@ -147,14 +147,14 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
* \return the version number of check point loaded * \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves * the p_model is not touched, user should do necessary initialization by themselves
* *
* Common usage example: * Common usage example:
* int iter = rabit::LoadCheckPoint(&model); * int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters(); * if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) { * for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce * do many things, include allreduce
* rabit::CheckPoint(model); * rabit::CheckPoint(model);
* } * }
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
@ -208,7 +208,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
* \brief internal consistency check function, * \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint * use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings * with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint * in the first call of LoadCheckPoint/CheckPoint
* *
* \param with_local whether the user calls CheckPoint with local model * \param with_local whether the user calls CheckPoint with local model
*/ */
@ -224,14 +224,14 @@ void AllreduceRobust::LocalModelCheck(bool with_local) {
num_local_replica = 0; num_local_replica = 0;
} }
} else { } else {
utils::Check(use_local_model == int(with_local), utils::Check(use_local_model == static_cast<int>(with_local),
"Can only call Checkpoint/LoadCheckPoint always with"\ "Can only call Checkpoint/LoadCheckPoint always with"\
"or without local_model, but not mixed case"); "or without local_model, but not mixed case");
} }
} }
/*! /*!
* \brief internal implementation of checkpoint, support both lazy and normal way * \brief internal implementation of checkpoint, support both lazy and normal way
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model * when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes * is the same in all nodes
@ -423,7 +423,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
* recover links according to the error type reported * recover links according to the error type reported
* if there is no error, return true * if there is no error, return true
* \param err_type the type of error happening in the system * \param err_type the type of error happening in the system
* \return true if err_type is kSuccess, false otherwise * \return true if err_type is kSuccess, false otherwise
*/ */
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
if (err_type == kSuccess) return true; if (err_type == kSuccess) return true;
@ -488,7 +488,7 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
* \brief message passing function, used to decide the * \brief message passing function, used to decide the
* data request from each edge, whether need to request data from certain edge * data request from each edge, whether need to request data from certain edge
* \param node_value a pair of request_data and best_link * \param node_value a pair of request_data and best_link
* request_data stores whether current node need to request data * request_data stores whether current node need to request data
* best_link gives the best edge index to fetch the data * best_link gives the best edge index to fetch the data
* \param req_in the data request from incoming edges * \param req_in the data request from incoming edges
* \param out_index the edge index of output link * \param out_index the edge index of output link
@ -524,7 +524,7 @@ inline char DataRequest(const std::pair<bool, int> &node_value,
* *
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType * \sa ReturnType
*/ */
AllreduceRobust::ReturnType AllreduceRobust::ReturnType
AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role, AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
size_t *p_size, size_t *p_size,
@ -586,7 +586,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
* *
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType, TryDecideRouting * \sa ReturnType, TryDecideRouting
*/ */
AllreduceRobust::ReturnType AllreduceRobust::ReturnType
AllreduceRobust::TryRecoverData(RecoverType role, AllreduceRobust::TryRecoverData(RecoverType role,
void *sendrecvbuf_, void *sendrecvbuf_,
@ -644,7 +644,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kRequestData) { if (role == kRequestData) {
const int pid = recv_link; const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) { if (selecter.CheckRead(links[pid].sock)) {
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size); ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
if (ret != kSuccess) { if (ret != kSuccess) {
return ReportError(&links[pid], ret); return ReportError(&links[pid], ret);
} }
@ -823,10 +823,10 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
* \param buf the buffer to store the result * \param buf the buffer to store the result
* \param size the total size of the buffer * \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary * \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set, * \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kSpecialOp * seqno needs to be set to ActionSummary::kSpecialOp
* *
* \return if this function can return true or false * \return if this function can return true or false
* - true means buf already set to the * - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed * result by recovering procedure, the action is complete, no further action is needed
* - false means this is the lastest action that has not yet been executed, need to execute the action * - false means this is the lastest action that has not yet been executed, need to execute the action
@ -907,7 +907,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
* plus replication of states in previous num_local_replica hops in the ring * plus replication of states in previous num_local_replica hops in the ring
* *
* The input parameters must contain the valid local states available in current nodes, * The input parameters must contain the valid local states available in current nodes,
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt * This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
* If there is sufficient information in the ring, when the function returns, local_chkpt will * If there is sufficient information in the ring, when the function returns, local_chkpt will
* contain num_local_replica + 1 checkpoints (including the chkpt of this node) * contain num_local_replica + 1 checkpoints (including the chkpt of this node)
* If there is no sufficient information in the ring, this function the number of checkpoints * If there is no sufficient information in the ring, this function the number of checkpoints

View File

@ -5,7 +5,7 @@
* using TCP non-block socket and tree-shape reduction. * using TCP non-block socket and tree-shape reduction.
* *
* This implementation considers the failure of nodes * This implementation considers the failure of nodes
* *
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#ifndef RABIT_ALLREDUCE_ROBUST_H_ #ifndef RABIT_ALLREDUCE_ROBUST_H_
@ -28,13 +28,13 @@ class AllreduceRobust : public AllreduceBase {
/*! \brief shutdown the engine */ /*! \brief shutdown the engine */
virtual void Shutdown(void); virtual void Shutdown(void);
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
* \param val parameter value * \param val parameter value
*/ */
virtual void SetParam(const char *name, const char *val); virtual void SetParam(const char *name, const char *val);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data * \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have * \param type_nbytes the unit number of bytes the type have
@ -69,14 +69,14 @@ class AllreduceRobust : public AllreduceBase {
* \return the version number of check point loaded * \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves * the p_model is not touched, user should do necessary initialization by themselves
* *
* Common usage example: * Common usage example:
* int iter = rabit::LoadCheckPoint(&model); * int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters(); * if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) { * for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce * do many things, include allreduce
* rabit::CheckPoint(model); * rabit::CheckPoint(model);
* } * }
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
@ -85,7 +85,7 @@ class AllreduceRobust : public AllreduceBase {
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model * when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes * is the same in all nodes
@ -105,16 +105,16 @@ class AllreduceRobust : public AllreduceBase {
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination). * when certain condition is met(see detailed expplaination).
* *
* This is a "lazy" checkpoint such that only the pointer to global_model is * This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that: * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs. * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of * In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version * Allreduce/Broadcast and LazyCheckPoint in current version
* *
* For example, suppose the calling sequence is: * For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
* *
* If user can only changes global_model in code3, then LazyCheckPoint can be used to * If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program. * improve efficiency of the program.
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
@ -287,6 +287,7 @@ class AllreduceRobust : public AllreduceBase {
if (seqno_.size() == 0) return -1; if (seqno_.size() == 0) return -1;
return seqno_.back(); return seqno_.back();
} }
private: private:
// sequence number of each // sequence number of each
std::vector<int> seqno_; std::vector<int> seqno_;
@ -301,14 +302,14 @@ class AllreduceRobust : public AllreduceBase {
* \brief internal consistency check function, * \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint * use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings * with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint * in the first call of LoadCheckPoint/CheckPoint
* *
* \param with_local whether the user calls CheckPoint with local model * \param with_local whether the user calls CheckPoint with local model
*/ */
void LocalModelCheck(bool with_local); void LocalModelCheck(bool with_local);
/*! /*!
* \brief internal implementation of checkpoint, support both lazy and normal way * \brief internal implementation of checkpoint, support both lazy and normal way
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model * when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes * is the same in all nodes
@ -326,10 +327,10 @@ class AllreduceRobust : public AllreduceBase {
* after this function finishes, all the messages received and sent * after this function finishes, all the messages received and sent
* before in all live links are discarded, * before in all live links are discarded,
* This allows us to get a fresh start after error has happened * This allows us to get a fresh start after error has happened
* *
* TODO(tqchen): this function is not yet functioning was not used by engine, * TODO(tqchen): this function is not yet functioning was not used by engine,
* simple resetlink and reconnect strategy is used * simple resetlink and reconnect strategy is used
* *
* \return this function can return kSuccess or kSockError * \return this function can return kSuccess or kSockError
* when kSockError is returned, it simply means there are bad sockets in the links, * when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery proceduer is needed * and some link recovery proceduer is needed
@ -340,7 +341,7 @@ class AllreduceRobust : public AllreduceBase {
* recover links according to the error type reported * recover links according to the error type reported
* if there is no error, return true * if there is no error, return true
* \param err_type the type of error happening in the system * \param err_type the type of error happening in the system
* \return true if err_type is kSuccess, false otherwise * \return true if err_type is kSuccess, false otherwise
*/ */
bool CheckAndRecover(ReturnType err_type); bool CheckAndRecover(ReturnType err_type);
/*! /*!
@ -355,7 +356,7 @@ class AllreduceRobust : public AllreduceBase {
* \param seqno sequence number of the action, if it is special action with flag set, * \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kSpecialOp * seqno needs to be set to ActionSummary::kSpecialOp
* *
* \return if this function can return true or false * \return if this function can return true or false
* - true means buf already set to the * - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed * result by recovering procedure, the action is complete, no further action is needed
* - false means this is the lastest action that has not yet been executed, need to execute the action * - false means this is the lastest action that has not yet been executed, need to execute the action
@ -364,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
int seqno = ActionSummary::kSpecialOp); int seqno = ActionSummary::kSpecialOp);
/*! /*!
* \brief try to load check point * \brief try to load check point
* *
* This is a collaborative function called by all nodes * This is a collaborative function called by all nodes
* only the nodes with requester set to true really needs to load the check point * only the nodes with requester set to true really needs to load the check point
* other nodes acts as collaborative roles to complete this request * other nodes acts as collaborative roles to complete this request
@ -395,7 +396,7 @@ class AllreduceRobust : public AllreduceBase {
* \param p_size used to store the size of the message, for node in state kHaveData, * \param p_size used to store the size of the message, for node in state kHaveData,
* this size must be set correctly before calling the function * this size must be set correctly before calling the function
* for others, this surves as output parameter * for others, this surves as output parameter
* \param p_recvlink used to store the link current node should recv data from, if necessary * \param p_recvlink used to store the link current node should recv data from, if necessary
* this can be -1, which means current node have the data * this can be -1, which means current node have the data
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to * \param p_req_in used to store the resulting vector, indicating which link we should send the data to
@ -432,7 +433,7 @@ class AllreduceRobust : public AllreduceBase {
* plus replication of states in previous num_local_replica hops in the ring * plus replication of states in previous num_local_replica hops in the ring
* *
* The input parameters must contain the valid local states available in current nodes, * The input parameters must contain the valid local states available in current nodes,
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt * This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
* If there is sufficient information in the ring, when the function returns, local_chkpt will * If there is sufficient information in the ring, when the function returns, local_chkpt will
* contain num_local_replica + 1 checkpoints (including the chkpt of this node) * contain num_local_replica + 1 checkpoints (including the chkpt of this node)
* If there is no sufficient information in the ring, this function the number of checkpoints * If there is no sufficient information in the ring, this function the number of checkpoints
@ -487,7 +488,7 @@ o * the input state must exactly one saved state(local state of current node)
LinkRecord *read_link, LinkRecord *read_link,
LinkRecord *write_link); LinkRecord *write_link);
/*! /*!
* \brief run message passing algorithm on the allreduce tree * \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out * the result is edge message stored in p_edge_in and p_edge_out
* \param node_value the value associated with current node * \param node_value the value associated with current node
* \param p_edge_in used to store input message from each of the edge * \param p_edge_in used to store input message from each of the edge
@ -509,7 +510,7 @@ o * the input state must exactly one saved state(local state of current node)
inline ReturnType MsgPassing(const NodeType &node_value, inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in, std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out, std::vector<EdgeType> *p_edge_out,
EdgeType (*func) EdgeType(*func)
(const NodeType &node_value, (const NodeType &node_value,
const std::vector<EdgeType> &edge_in, const std::vector<EdgeType> &edge_in,
size_t out_index)); size_t out_index));

View File

@ -3,7 +3,7 @@
* \file engine.cc * \file engine.cc
* \brief this file governs which implementation of engine we are actually using * \brief this file governs which implementation of engine we are actually using
* provides an singleton of engine interface * provides an singleton of engine interface
* *
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
@ -60,7 +60,7 @@ void Allreduce_(void *sendrecvbuf,
} }
// code for reduce handle // code for reduce handle
ReduceHandle::ReduceHandle(void) ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) { : handle_(NULL), redfunc_(NULL), htype_(NULL) {
} }
ReduceHandle::~ReduceHandle(void) {} ReduceHandle::~ReduceHandle(void) {}

View File

@ -3,7 +3,7 @@
* \file engine_mpi.cc * \file engine_mpi.cc
* \brief this file gives an implementation of engine interface using MPI, * \brief this file gives an implementation of engine interface using MPI,
* this will allow rabit program to run with MPI, but do not comes with fault tolerant * this will allow rabit program to run with MPI, but do not comes with fault tolerant
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
@ -143,7 +143,7 @@ void Allreduce_(void *sendrecvbuf,
} }
// code for reduce handle // code for reduce handle
ReduceHandle::ReduceHandle(void) ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) { : handle_(NULL), redfunc_(NULL), htype_(NULL) {
} }
ReduceHandle::~ReduceHandle(void) { ReduceHandle::~ReduceHandle(void) {
@ -166,7 +166,7 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
if (type_nbytes != 0) { if (type_nbytes != 0) {
MPI::Datatype *dtype = new MPI::Datatype(); MPI::Datatype *dtype = new MPI::Datatype();
if (type_nbytes % 8 == 0) { if (type_nbytes % 8 == 0) {
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
} else if (type_nbytes % 4 == 0) { } else if (type_nbytes % 4 == 0) {
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
} else { } else {
@ -195,7 +195,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
dtype->Free(); dtype->Free();
} }
if (type_nbytes % 8 == 0) { if (type_nbytes % 8 == 0) {
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
} else if (type_nbytes % 4 == 0) { } else if (type_nbytes % 4 == 0) {
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
} else { } else {

View File

@ -51,7 +51,7 @@ struct SockAddr {
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
return std::string(buf.c_str()); return std::string(buf.c_str());
} }
/*! /*!
* \brief set the address * \brief set the address
* \param url the url of the address * \param url the url of the address
* \param port the port of address * \param port the port of address
@ -83,7 +83,7 @@ struct SockAddr {
} }
}; };
/*! /*!
* \brief base class containing common operations of TCP and UDP sockets * \brief base class containing common operations of TCP and UDP sockets
*/ */
class Socket { class Socket {
@ -95,7 +95,7 @@ class Socket {
return sockfd; return sockfd;
} }
/*! /*!
* \return last error of socket operation * \return last error of socket operation
*/ */
inline static int GetLastError(void) { inline static int GetLastError(void) {
#ifdef _WIN32 #ifdef _WIN32
@ -106,7 +106,7 @@ class Socket {
} }
/*! \return whether last error was would block */ /*! \return whether last error was would block */
inline static bool LastErrorWouldBlock(void) { inline static bool LastErrorWouldBlock(void) {
int errsv = GetLastError(); int errsv = GetLastError();
#ifdef _WIN32 #ifdef _WIN32
return errsv == WSAEWOULDBLOCK; return errsv == WSAEWOULDBLOCK;
#else #else
@ -129,15 +129,15 @@ class Socket {
} }
#endif #endif
} }
/*! /*!
* \brief shutdown the socket module after use, all sockets need to be closed * \brief shutdown the socket module after use, all sockets need to be closed
*/ */
inline static void Finalize(void) { inline static void Finalize(void) {
#ifdef _WIN32 #ifdef _WIN32
WSACleanup(); WSACleanup();
#endif #endif
} }
/*! /*!
* \brief set this socket to use non-blocking mode * \brief set this socket to use non-blocking mode
* \param non_block whether set it to be non-block, if it is false * \param non_block whether set it to be non-block, if it is false
* it will set it back to block mode * it will set it back to block mode
@ -163,8 +163,8 @@ class Socket {
} }
#endif #endif
} }
/*! /*!
* \brief bind the socket to an address * \brief bind the socket to an address
* \param addr * \param addr
*/ */
inline void Bind(const SockAddr &addr) { inline void Bind(const SockAddr &addr) {
@ -173,7 +173,7 @@ class Socket {
Socket::Error("Bind"); Socket::Error("Bind");
} }
} }
/*! /*!
* \brief try bind the socket to host, from start_port to end_port * \brief try bind the socket to host, from start_port to end_port
* \param start_port starting port number to try * \param start_port starting port number to try
* \param end_port ending port number to try * \param end_port ending port number to try
@ -188,11 +188,11 @@ class Socket {
return port; return port;
} }
#if defined(_WIN32) #if defined(_WIN32)
if (WSAGetLastError() != WSAEADDRINUSE) { if (WSAGetLastError() != WSAEADDRINUSE) {
Socket::Error("TryBindHost"); Socket::Error("TryBindHost");
} }
#else #else
if (errno != EADDRINUSE) { if (errno != EADDRINUSE) {
Socket::Error("TryBindHost"); Socket::Error("TryBindHost");
} }
#endif #endif
@ -248,7 +248,7 @@ class Socket {
} }
}; };
/*! /*!
* \brief a wrapper of TCP socket that hopefully be cross platform * \brief a wrapper of TCP socket that hopefully be cross platform
*/ */
class TCPSocket : public Socket{ class TCPSocket : public Socket{
@ -261,10 +261,11 @@ class TCPSocket : public Socket{
/*! /*!
* \brief enable/disable TCP keepalive * \brief enable/disable TCP keepalive
* \param keepalive whether to set the keep alive option on * \param keepalive whether to set the keep alive option on
*/ */
inline void SetKeepAlive(bool keepalive) { inline void SetKeepAlive(bool keepalive) {
int opt = static_cast<int>(keepalive); int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) { if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive"); Socket::Error("SetKeepAlive");
} }
} }
@ -294,12 +295,12 @@ class TCPSocket : public Socket{
return TCPSocket(newfd); return TCPSocket(newfd);
} }
/*! /*!
* \brief decide whether the socket is at OOB mark * \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occured * \return 1 if at mark, 0 if not, -1 if an error occured
*/ */
inline int AtMark(void) const { inline int AtMark(void) const {
#ifdef _WIN32 #ifdef _WIN32
unsigned long atmark; unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else #else
int atmark; int atmark;
@ -307,8 +308,8 @@ class TCPSocket : public Socket{
#endif #endif
return static_cast<int>(atmark); return static_cast<int>(atmark);
} }
/*! /*!
* \brief connect to an address * \brief connect to an address
* \param addr the address to connect to * \param addr the address to connect to
* \return whether connect is successful * \return whether connect is successful
*/ */
@ -328,8 +329,8 @@ class TCPSocket : public Socket{
const char *buf = reinterpret_cast<const char*>(buf_); const char *buf = reinterpret_cast<const char*>(buf_);
return send(sockfd, buf, static_cast<sock_size_t>(len), flag); return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
} }
/*! /*!
* \brief receive data using the socket * \brief receive data using the socket
* \param buf_ the pointer to the buffer * \param buf_ the pointer to the buffer
* \param len the size of the buffer * \param len the size of the buffer
* \param flags extra flags * \param flags extra flags
@ -385,7 +386,7 @@ class TCPSocket : public Socket{
return ndone; return ndone;
} }
/*! /*!
* \brief send a string over network * \brief send a string over network
* \param str the string to be sent * \param str the string to be sent
*/ */
inline void SendStr(const std::string &str) { inline void SendStr(const std::string &str) {
@ -423,7 +424,7 @@ struct SelectHelper {
maxfd = 0; maxfd = 0;
} }
/*! /*!
* \brief add file descriptor to watch for read * \brief add file descriptor to watch for read
* \param fd file descriptor to be watched * \param fd file descriptor to be watched
*/ */
inline void WatchRead(SOCKET fd) { inline void WatchRead(SOCKET fd) {
@ -473,7 +474,7 @@ struct SelectHelper {
* \param timeout the timeout counter, can be 0, which means wait until the event happen * \param timeout the timeout counter, can be 0, which means wait until the event happen
* \return 1 if success, 0 if timeout, and -1 if error occurs * \return 1 if success, 0 if timeout, and -1 if error occurs
*/ */
inline static int WaitExcept(SOCKET fd, long timeout = 0) { inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*)
fd_set wait_set; fd_set wait_set;
FD_ZERO(&wait_set); FD_ZERO(&wait_set);
FD_SET(fd, &wait_set); FD_SET(fd, &wait_set);
@ -486,10 +487,10 @@ struct SelectHelper {
* \param select_write whether to watch for write event * \param select_write whether to watch for write event
* \param select_except whether to watch for exception event * \param select_except whether to watch for exception event
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block * \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
* \return number of active descriptors selected, * \return number of active descriptors selected,
* return -1 if error occurs * return -1 if error occurs
*/ */
inline int Select(long timeout = 0) { inline int Select(long timeout = 0) { // NOLINT(*)
int ret = Select_(static_cast<int>(maxfd + 1), int ret = Select_(static_cast<int>(maxfd + 1),
&read_set, &write_set, &except_set, timeout); &read_set, &write_set, &except_set, timeout);
if (ret == -1) { if (ret == -1) {
@ -500,7 +501,7 @@ struct SelectHelper {
private: private:
inline static int Select_(int maxfd, fd_set *rfds, inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) { fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*)
#if !defined(_WIN32) #if !defined(_WIN32)
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE"); utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
#endif #endif

View File

@ -1,7 +1,7 @@
# this is a makefile used to show testcases of rabit # this is a makefile used to show testcases of rabit
.PHONY: all .PHONY: all
all: all: model_recover_10_10k model_recover_10_10k_die_same
# this experiment test recovery with actually process exit, use keepalive to keep program alive # this experiment test recovery with actually process exit, use keepalive to keep program alive
model_recover_10_10k: model_recover_10_10k:

View File

@ -3,6 +3,7 @@ Python interface for rabit
Reliable Allreduce and Broadcast Library Reliable Allreduce and Broadcast Library
Author: Tianqi Chen Author: Tianqi Chen
""" """
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
import cPickle as pickle import cPickle as pickle
import ctypes import ctypes
import os import os
@ -17,10 +18,12 @@ else:
rbtlib = None rbtlib = None
# load in xgboost library # load in xgboost library
def loadlib__(lib = 'standard'): def loadlib__(lib='standard'):
"""Load rabit library"""
global rbtlib global rbtlib
if rbtlib != None: if rbtlib != None:
warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2) warnings.warn('rabit.int call was ignored because it has'\
' already been initialized', level=2)
return return
if lib == 'standard': if lib == 'standard':
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '') rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '')
@ -35,6 +38,7 @@ def loadlib__(lib = 'standard'):
rbtlib.RabitVersionNumber.restype = ctypes.c_int rbtlib.RabitVersionNumber.restype = ctypes.c_int
def unloadlib__(): def unloadlib__():
"""Unload rabit library"""
global rbtlib global rbtlib
del rbtlib del rbtlib
rbtlib = None rbtlib = None
@ -45,13 +49,13 @@ MIN = 1
SUM = 2 SUM = 2
BITOR = 3 BITOR = 3
def check_err__(): def check_err__():
""" """
reserved function used to check error reserved function used to check error
""" """
return return
def init(args = sys.argv, lib = 'standard'): def init(args=sys.argv, lib='standard'):
""" """
intialize the rabit module, call this once before using anything intialize the rabit module, call this once before using anything
Arguments: Arguments:
@ -69,7 +73,7 @@ def init(args = sys.argv, lib = 'standard'):
def finalize(): def finalize():
""" """
finalize the rabit engine, call this function after you finished all jobs finalize the rabit engine, call this function after you finished all jobs
""" """
rbtlib.RabitFinalize() rbtlib.RabitFinalize()
check_err__() check_err__()
@ -132,7 +136,7 @@ def broadcast(data, root):
print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s)) print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s))
rabit.finalize() rabit.finalize()
``` ```
Arguments: Arguments:
data: anytype that can be pickled data: anytype that can be pickled
input data, if current rank does not equal root, this can be None input data, if current rank does not equal root, this can be None
@ -145,12 +149,12 @@ def broadcast(data, root):
length = ctypes.c_ulong() length = ctypes.c_ulong()
if root == rank: if root == rank:
assert data is not None, 'need to pass in data when broadcasting' assert data is not None, 'need to pass in data when broadcasting'
s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL) s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
length.value = len(s) length.value = len(s)
# run first broadcast # run first broadcast
rbtlib.RabitBroadcast(ctypes.byref(length), rbtlib.RabitBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong), ctypes.sizeof(ctypes.c_ulong),
root) root)
check_err__() check_err__()
if root != rank: if root != rank:
dptr = (ctypes.c_char * length.value)() dptr = (ctypes.c_char * length.value)()
@ -179,18 +183,19 @@ DTYPE_ENUM__ = {
np.dtype('float64') : 7 np.dtype('float64') : 7
} }
def allreduce(data, op, prepare_fun = None): def allreduce(data, op, prepare_fun=None):
""" """
perform allreduce, return the result, this function is not thread-safe perform allreduce, return the result, this function is not thread-safe
Arguments: Arguments:
data: numpy ndarray data: numpy ndarray
input data input data
op: int op: int
reduction operators, can be MIN, MAX, SUM, BITOR reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: lambda data prepare_fun: lambda data
Lazy preprocessing function, if it is not None, prepare_fun(data) Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to intialize the data will be called by the function before performing allreduce, to intialize the data
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called
Returns: Returns:
the result of allreduce, have same shape as data the result of allreduce, have same shape as data
""" """
@ -206,12 +211,13 @@ def allreduce(data, op, prepare_fun = None):
buf.size, DTYPE_ENUM__[buf.dtype], buf.size, DTYPE_ENUM__[buf.dtype],
op, None, None) op, None, None)
else: else:
PFUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p) func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pfunc(args): def pfunc(args):
"""prepare function."""
prepare_fun(data) prepare_fun(data)
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype], buf.size, DTYPE_ENUM__[buf.dtype],
op, PFUNC(pfunc), None) op, func_ptr(pfunc), None)
check_err__() check_err__()
return buf return buf
@ -229,49 +235,49 @@ def load_model__(ptr, length):
data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents)) data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents))
return pickle.loads(data.raw) return pickle.loads(data.raw)
def load_checkpoint(with_local = False): def load_checkpoint(with_local=False):
""" """
load latest check point load latest check point
Arguments: Arguments:
with_local: boolean [default = False] with_local: boolean [default = False]
whether the checkpoint contains local model whether the checkpoint contains local model
Returns: Returns:
if with_local: return (version, gobal_model, local_model) if with_local: return (version, gobal_model, local_model)
else return (version, gobal_model) else return (version, gobal_model)
if returned version == 0, this means no model has been CheckPointed if returned version == 0, this means no model has been CheckPointed
and global_model, local_model returned will be None and global_model, local_model returned will be None
""" """
gp = ctypes.POINTER(ctypes.c_char)() gptr = ctypes.POINTER(ctypes.c_char)()
global_len = ctypes.c_ulong() global_len = ctypes.c_ulong()
if with_local: if with_local:
lp = ctypes.POINTER(ctypes.c_char)() lptr = ctypes.POINTER(ctypes.c_char)()
local_len = ctypes.c_ulong() local_len = ctypes.c_ulong()
version = rbtlib.RabitLoadCheckPoint( version = rbtlib.RabitLoadCheckPoint(
ctypes.byref(gp), ctypes.byref(gptr),
ctypes.byref(global_len), ctypes.byref(global_len),
ctypes.byref(lp), ctypes.byref(lptr),
ctypes.byref(local_len)) ctypes.byref(local_len))
check_err__() check_err__()
if version == 0: if version == 0:
return (version, None, None) return (version, None, None)
return (version, return (version,
load_model__(gp, global_len.value), load_model__(gptr, global_len.value),
load_model__(lp, local_len.value)) load_model__(lptr, local_len.value))
else: else:
version = rbtlib.RabitLoadCheckPoint( version = rbtlib.RabitLoadCheckPoint(
ctypes.byref(gp), ctypes.byref(gptr),
ctypes.byref(global_len), ctypes.byref(global_len),
None, None) None, None)
check_err__() check_err__()
if version == 0: if version == 0:
return (version, None) return (version, None)
return (version, return (version,
load_model__(gp, global_len.value)) load_model__(gptr, global_len.value))
def checkpoint(global_model, local_model = None): def checkpoint(global_model, local_model=None):
""" """
checkpoint the model, meaning we finished a stage of execution checkpoint the model, meaning we finished a stage of execution
every time we call check point, there is a version number which will increase by one every time we call check point, there is a version number which will increase by one
Arguments: Arguments:
global_model: anytype that can be pickled global_model: anytype that can be pickled
@ -285,16 +291,17 @@ def checkpoint(global_model, local_model = None):
while global_model do not need explicit replication. while global_model do not need explicit replication.
It is recommended to use global_model if possible It is recommended to use global_model if possible
""" """
sg = pickle.dumps(global_model) sglobal = pickle.dumps(global_model)
if local_model is None: if local_model is None:
rbtlib.RabitCheckPoint(sg, len(sg), None, 0) rbtlib.RabitCheckPoint(sglobal, len(sglobal), None, 0)
check_err__() check_err__()
del sg; del sglobal
else: else:
sl = pickle.dumps(local_model) slocal = pickle.dumps(local_model)
rbtlib.RabitCheckPoint(sg, len(sg), sl, len(sl)) rbtlib.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal))
check_err__() check_err__()
del sl; del sg; del slocal
del sglobal
def version_number(): def version_number():
""" """

View File

@ -1,3 +1,4 @@
// Copyright by Contributors
// implementations in ctypes // implementations in ctypes
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
@ -28,7 +29,7 @@ struct FHelper<op::BitOR, DType> {
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
utils::Error("DataType does not support bitwise or operation"); utils::Error("DataType does not support bitwise or operation");
} }
}; };
template<typename OP> template<typename OP>
inline void Allreduce_(void *sendrecvbuf_, inline void Allreduce_(void *sendrecvbuf_,
@ -60,12 +61,12 @@ inline void Allreduce_(void *sendrecvbuf_,
return; return;
case kLong: case kLong:
rabit::Allreduce<OP> rabit::Allreduce<OP>
(static_cast<long*>(sendrecvbuf_), (static_cast<long*>(sendrecvbuf_), // NOLINT(*)
count, prepare_fun, prepare_arg); count, prepare_fun, prepare_arg);
return; return;
case kULong: case kULong:
rabit::Allreduce<OP> rabit::Allreduce<OP>
(static_cast<unsigned long*>(sendrecvbuf_), (static_cast<unsigned long*>(sendrecvbuf_), // NOLINT(*)
count, prepare_fun, prepare_arg); count, prepare_fun, prepare_arg);
return; return;
case kFloat: case kFloat:
@ -135,7 +136,7 @@ struct ReadWrapper : public Serializable {
} }
virtual void Save(Stream *fo) const { virtual void Save(Stream *fo) const {
utils::Error("not implemented"); utils::Error("not implemented");
} }
}; };
struct WriteWrapper : public Serializable { struct WriteWrapper : public Serializable {
const char *data; const char *data;
@ -179,7 +180,7 @@ extern "C" {
if (s.length() > max_len) { if (s.length() > max_len) {
s.resize(max_len - 1); s.resize(max_len - 1);
} }
strcpy(out_name, s.c_str()); strcpy(out_name, s.c_str()); // NOLINT(*)
*out_len = static_cast<rbt_ulong>(s.length()); *out_len = static_cast<rbt_ulong>(s.length());
} }
void RabitBroadcast(void *sendrecv_data, void RabitBroadcast(void *sendrecv_data,
@ -218,7 +219,7 @@ extern "C" {
*out_local_model = BeginPtr(local_buffer); *out_local_model = BeginPtr(local_buffer);
*out_local_len = static_cast<rbt_ulong>(local_buffer.length()); *out_local_len = static_cast<rbt_ulong>(local_buffer.length());
} }
return version; return version;
} }
void RabitCheckPoint(const char *global_model, void RabitCheckPoint(const char *global_model,
rbt_ulong global_len, rbt_ulong global_len,

View File

@ -1,18 +1,19 @@
#ifndef RABIT_WRAPPER_H_
#define RABIT_WRAPPER_H_
/*! /*!
* Copyright by Contributors
* \file rabit_wrapper.h * \file rabit_wrapper.h
* \author Tianqi Chen * \author Tianqi Chen
* \brief a C style wrapper of rabit * \brief a C style wrapper of rabit
* can be used to create wrapper of other languages * can be used to create wrapper of other languages
*/ */
#ifndef RABIT_WRAPPER_H_
#define RABIT_WRAPPER_H_
#ifdef _MSC_VER #ifdef _MSC_VER
#define RABIT_DLL __declspec(dllexport) #define RABIT_DLL __declspec(dllexport)
#else #else
#define RABIT_DLL #define RABIT_DLL
#endif #endif
// manually define unsign long // manually define unsign long
typedef unsigned long rbt_ulong; typedef unsigned long rbt_ulong; // NOLINT(*)
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -23,8 +24,8 @@ extern "C" {
* \param argv the array of input arguments * \param argv the array of input arguments
*/ */
RABIT_DLL void RabitInit(int argc, char *argv[]); RABIT_DLL void RabitInit(int argc, char *argv[]);
/*! /*!
* \brief finalize the rabit engine, call this function after you finished all jobs * \brief finalize the rabit engine, call this function after you finished all jobs
*/ */
RABIT_DLL void RabitFinalize(void); RABIT_DLL void RabitFinalize(void);
/*! \brief get rank of current process */ /*! \brief get rank of current process */
@ -37,9 +38,9 @@ extern "C" {
* the user who monitors the tracker * the user who monitors the tracker
* \param msg the message to be printed * \param msg the message to be printed
*/ */
RABIT_DLL void RabitTrackerPrint(const char *msg); RABIT_DLL void RabitTrackerPrint(const char *msg);
/*! /*!
* \brief get name of processor * \brief get name of processor
* \param out_name hold output string * \param out_name hold output string
* \param out_len hold length of output string * \param out_len hold length of output string
* \param max_len maximum buffer length of input * \param max_len maximum buffer length of input
@ -50,7 +51,7 @@ extern "C" {
/*! /*!
* \brief broadcast an memory region to all others from root * \brief broadcast an memory region to all others from root
* *
* Example: int a = 1; Broadcast(&a, sizeof(a), root); * Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to send or recive buffer, * \param sendrecv_data the pointer to send or recive buffer,
* \param size the size of the data * \param size the size of the data
* \param root the root of process * \param root the root of process
@ -58,7 +59,7 @@ extern "C" {
RABIT_DLL void RabitBroadcast(void *sendrecv_data, RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root); rbt_ulong size, int root);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
* *
* Example Usage: the following code gives sum of the result * Example Usage: the following code gives sum of the result
@ -81,14 +82,14 @@ extern "C" {
int enum_op, int enum_op,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg); void *prepare_arg);
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param out_global_model hold output of serialized global_model * \param out_global_model hold output of serialized global_model
* \param out_global_len the output length of serialized global model * \param out_global_len the output length of serialized global model
* \param out_local_model hold output of serialized local_model, can be NULL * \param out_local_model hold output of serialized local_model, can be NULL
* \param out_local_len the output length of serialized local model, can be NULL * \param out_local_len the output length of serialized local model, can be NULL
* *
* \return the version number of check point loaded * \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed * if returned version == 0, this means no model has been CheckPointed
* nothing will be touched * nothing will be touched
@ -100,7 +101,7 @@ extern "C" {
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
* *
* \param global_model hold content of serialized global_model * \param global_model hold content of serialized global_model
* \param global_len the content length of serialized global model * \param global_len the content length of serialized global model
* \param local_model hold content of serialized local_model, can be NULL * \param local_model hold content of serialized local_model, can be NULL
@ -122,4 +123,4 @@ extern "C" {
#ifdef __cplusplus #ifdef __cplusplus
} // C } // C
#endif #endif
#endif // XGBOOST_WRAPPER_H_ #endif // RABIT_WRAPPER_H_