lint and travis
This commit is contained in:
parent
ceedf4ea96
commit
3cc49ad0e8
1
.gitignore
vendored
1
.gitignore
vendored
@ -34,3 +34,4 @@
|
|||||||
*tmp*
|
*tmp*
|
||||||
*.rabit
|
*.rabit
|
||||||
*.mock
|
*.mock
|
||||||
|
dmlc-core
|
||||||
|
|||||||
48
.travis.yml
Normal file
48
.travis.yml
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# disable sudo to use container based build
|
||||||
|
sudo: false
|
||||||
|
|
||||||
|
# Use Build Matrix to do lint and build seperately
|
||||||
|
env:
|
||||||
|
matrix:
|
||||||
|
- TASK=lint LINT_LANG=cpp
|
||||||
|
- TASK=lint LINT_LANG=python
|
||||||
|
- TASK=doc
|
||||||
|
- TASK=build CXX=g++
|
||||||
|
|
||||||
|
# dependent apt packages
|
||||||
|
addons:
|
||||||
|
apt:
|
||||||
|
packages:
|
||||||
|
- doxygen
|
||||||
|
- wget
|
||||||
|
- git
|
||||||
|
- libcurl4-openssl-dev
|
||||||
|
- unzip
|
||||||
|
|
||||||
|
before_install:
|
||||||
|
- git clone https://github.com/dmlc/dmlc-core
|
||||||
|
- export TRAVIS=dmlc-core/scripts/travis/
|
||||||
|
- source ${TRAVIS}/travis_setup_env.sh
|
||||||
|
|
||||||
|
install:
|
||||||
|
- pip install cpplint pylint --user `whoami`
|
||||||
|
|
||||||
|
script: scripts/travis_script.sh
|
||||||
|
|
||||||
|
|
||||||
|
before_cache:
|
||||||
|
- ${TRAVIS}/travis_before_cache.sh
|
||||||
|
|
||||||
|
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- ${HOME}/.cache/usr
|
||||||
|
|
||||||
|
|
||||||
|
notifications:
|
||||||
|
# Emails are sent to the committer's git-configured email address by default,
|
||||||
|
email:
|
||||||
|
on_success: change
|
||||||
|
on_failure: always
|
||||||
|
|
||||||
|
|
||||||
24
Makefile
24
Makefile
@ -3,14 +3,18 @@ export CXX = g++
|
|||||||
endif
|
endif
|
||||||
export MPICXX = mpicxx
|
export MPICXX = mpicxx
|
||||||
export LDFLAGS= -Llib -lrt
|
export LDFLAGS= -Llib -lrt
|
||||||
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas
|
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas
|
||||||
export CFLAGS = -O3 -msse2 $(WARNFLAGS)
|
export CFLAGS = -O3 -msse2 $(WARNFLAGS)
|
||||||
|
|
||||||
ifndef WITH_FPIC
|
ifndef WITH_FPIC
|
||||||
WITH_FPIC = 1
|
WITH_FPIC = 1
|
||||||
endif
|
endif
|
||||||
ifeq ($(WITH_FPIC), 1)
|
ifeq ($(WITH_FPIC), 1)
|
||||||
CFLAGS += -fPIC
|
CFLAGS += -fPIC
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifndef LINT_LANG
|
||||||
|
LINT_LANG="all"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
# build path
|
# build path
|
||||||
@ -22,7 +26,9 @@ OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(B
|
|||||||
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so wrapper/librabit_wrapper_mpi.so
|
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so wrapper/librabit_wrapper_mpi.so
|
||||||
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a
|
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a
|
||||||
HEADERS=src/*.h include/*.h include/rabit/*.h
|
HEADERS=src/*.h include/*.h include/rabit/*.h
|
||||||
.PHONY: clean all install mpi python
|
DMLC=dmlc-core
|
||||||
|
|
||||||
|
.PHONY: clean all install mpi python lint doc
|
||||||
|
|
||||||
all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a
|
all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a
|
||||||
mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
|
mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
|
||||||
@ -47,10 +53,10 @@ wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
|||||||
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
|
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
|
||||||
wrapper/librabit_wrapper_mpi.so: $(BPATH)/rabit_wrapper.o lib/librabit_mpi.a
|
wrapper/librabit_wrapper_mpi.so: $(BPATH)/rabit_wrapper.o lib/librabit_mpi.a
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|
||||||
$(MPIOBJ) :
|
$(MPIOBJ) :
|
||||||
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|
||||||
$(ALIB):
|
$(ALIB):
|
||||||
@ -59,6 +65,12 @@ $(ALIB):
|
|||||||
$(SLIB) :
|
$(SLIB) :
|
||||||
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
|
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
|
||||||
|
|
||||||
|
lint:
|
||||||
|
$(DMLC)/scripts/lint.py rabit $(LINT_LANG) src include wrapper
|
||||||
|
|
||||||
|
doc:
|
||||||
|
cd include; doxygen ../doc/Doxyfile; cd -
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~
|
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~
|
||||||
|
|
||||||
|
|||||||
@ -101,8 +101,8 @@ FILE_PATTERNS =
|
|||||||
RECURSIVE = NO
|
RECURSIVE = NO
|
||||||
EXCLUDE =
|
EXCLUDE =
|
||||||
EXCLUDE_SYMLINKS = NO
|
EXCLUDE_SYMLINKS = NO
|
||||||
EXCLUDE_PATTERNS = *-inl.hpp
|
EXCLUDE_PATTERNS = *-inl.hpp
|
||||||
EXCLUDE_SYMBOLS =
|
EXCLUDE_SYMBOLS =
|
||||||
EXAMPLE_PATH =
|
EXAMPLE_PATH =
|
||||||
EXAMPLE_PATTERNS =
|
EXAMPLE_PATTERNS =
|
||||||
EXAMPLE_RECURSIVE = NO
|
EXAMPLE_RECURSIVE = NO
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
cd ../include
|
cd ../include
|
||||||
doxygen ../doc/Doxyfile
|
doxygen ../doc/Doxyfile
|
||||||
cd ../doc
|
cd ../doc
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
// include uint64_t only to make io standalone
|
// include uint64_t only to make io standalone
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
|
/*! \brief uint64 */
|
||||||
typedef unsigned __int64 uint64_t;
|
typedef unsigned __int64 uint64_t;
|
||||||
#else
|
#else
|
||||||
#include <inttypes.h>
|
#include <inttypes.h>
|
||||||
@ -24,7 +25,7 @@ namespace dmlc {
|
|||||||
/*!
|
/*!
|
||||||
* \brief interface of stream I/O for serialization
|
* \brief interface of stream I/O for serialization
|
||||||
*/
|
*/
|
||||||
class Stream {
|
class Stream { // NOLINT(*)
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
* \brief reads data from a stream
|
* \brief reads data from a stream
|
||||||
@ -71,7 +72,7 @@ class Stream {
|
|||||||
/*!
|
/*!
|
||||||
* \brief writes a string
|
* \brief writes a string
|
||||||
* \param str the string to be written/serialized
|
* \param str the string to be written/serialized
|
||||||
*/
|
*/
|
||||||
inline void Write(const std::string &str);
|
inline void Write(const std::string &str);
|
||||||
/*!
|
/*!
|
||||||
* \brief loads a string
|
* \brief loads a string
|
||||||
@ -94,7 +95,7 @@ class SeekStream: public Stream {
|
|||||||
* \brief generic factory function
|
* \brief generic factory function
|
||||||
* create an SeekStream for read only,
|
* create an SeekStream for read only,
|
||||||
* the stream will close the underlying files upon deletion
|
* the stream will close the underlying files upon deletion
|
||||||
* error will be reported and the system will exit when create failed
|
* error will be reported and the system will exit when create failed
|
||||||
* \param uri the uri of the input currently we support
|
* \param uri the uri of the input currently we support
|
||||||
* hdfs://, s3://, and file:// by default file:// will be used
|
* hdfs://, s3://, and file:// by default file:// will be used
|
||||||
* \param allow_null whether NULL can be returned, or directly report error
|
* \param allow_null whether NULL can be returned, or directly report error
|
||||||
@ -107,12 +108,12 @@ class SeekStream: public Stream {
|
|||||||
/*! \brief interface for serializable objects */
|
/*! \brief interface for serializable objects */
|
||||||
class Serializable {
|
class Serializable {
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
* \brief load the model from a stream
|
* \brief load the model from a stream
|
||||||
* \param fi stream where to load the model from
|
* \param fi stream where to load the model from
|
||||||
*/
|
*/
|
||||||
virtual void Load(Stream *fi) = 0;
|
virtual void Load(Stream *fi) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief saves the model to a stream
|
* \brief saves the model to a stream
|
||||||
* \param fo stream where to save the model to
|
* \param fo stream where to save the model to
|
||||||
*/
|
*/
|
||||||
@ -123,7 +124,7 @@ class Serializable {
|
|||||||
* \brief input split creates that allows reading
|
* \brief input split creates that allows reading
|
||||||
* of records from split of data,
|
* of records from split of data,
|
||||||
* independent part that covers all the dataset
|
* independent part that covers all the dataset
|
||||||
*
|
*
|
||||||
* see InputSplit::Create for definition of record
|
* see InputSplit::Create for definition of record
|
||||||
*/
|
*/
|
||||||
class InputSplit {
|
class InputSplit {
|
||||||
@ -141,7 +142,7 @@ class InputSplit {
|
|||||||
* this is a hint so may not be enforced,
|
* this is a hint so may not be enforced,
|
||||||
* but InputSplit will try adjust its internal buffer
|
* but InputSplit will try adjust its internal buffer
|
||||||
* size to the hinted value
|
* size to the hinted value
|
||||||
* \param chunk_size the chunk size
|
* \param chunk_size the chunk size
|
||||||
*/
|
*/
|
||||||
virtual void HintChunkSize(size_t chunk_size) {}
|
virtual void HintChunkSize(size_t chunk_size) {}
|
||||||
/*! \brief reset the position of InputSplit to beginning */
|
/*! \brief reset the position of InputSplit to beginning */
|
||||||
@ -150,7 +151,7 @@ class InputSplit {
|
|||||||
* \brief get the next record, the returning value
|
* \brief get the next record, the returning value
|
||||||
* is valid until next call to NextRecord or NextChunk
|
* is valid until next call to NextRecord or NextChunk
|
||||||
* caller can modify the memory content of out_rec
|
* caller can modify the memory content of out_rec
|
||||||
*
|
*
|
||||||
* For text, out_rec contains a single line
|
* For text, out_rec contains a single line
|
||||||
* For recordio, out_rec contains one record content(with header striped)
|
* For recordio, out_rec contains one record content(with header striped)
|
||||||
*
|
*
|
||||||
@ -161,11 +162,11 @@ class InputSplit {
|
|||||||
*/
|
*/
|
||||||
virtual bool NextRecord(Blob *out_rec) = 0;
|
virtual bool NextRecord(Blob *out_rec) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief get a chunk of memory that can contain multiple records,
|
* \brief get a chunk of memory that can contain multiple records,
|
||||||
* the caller needs to parse the content of the resulting chunk,
|
* the caller needs to parse the content of the resulting chunk,
|
||||||
* for text file, out_chunk can contain data of multiple lines
|
* for text file, out_chunk can contain data of multiple lines
|
||||||
* for recordio, out_chunk can contain multiple records(including headers)
|
* for recordio, out_chunk can contain multiple records(including headers)
|
||||||
*
|
*
|
||||||
* This function ensures there won't be partial record in the chunk
|
* This function ensures there won't be partial record in the chunk
|
||||||
* caller can modify the memory content of out_chunk,
|
* caller can modify the memory content of out_chunk,
|
||||||
* the memory is valid until next call to NextRecord or NextChunk
|
* the memory is valid until next call to NextRecord or NextChunk
|
||||||
@ -192,9 +193,10 @@ class InputSplit {
|
|||||||
* List of possible types: "text", "recordio"
|
* List of possible types: "text", "recordio"
|
||||||
* - "text":
|
* - "text":
|
||||||
* text file, each line is treated as a record
|
* text file, each line is treated as a record
|
||||||
* input split will split on \n or \r
|
* input split will split on '\\n' or '\\r'
|
||||||
* - "recordio":
|
* - "recordio":
|
||||||
* binary recordio file, see recordio.h
|
* binary recordio file, see recordio.h
|
||||||
|
* \return a new input split
|
||||||
* \sa InputSplit::Type
|
* \sa InputSplit::Type
|
||||||
*/
|
*/
|
||||||
static InputSplit* Create(const char *uri,
|
static InputSplit* Create(const char *uri,
|
||||||
@ -224,7 +226,7 @@ class ostream : public std::basic_ostream<char> {
|
|||||||
* \param buffer_size internal streambuf size
|
* \param buffer_size internal streambuf size
|
||||||
*/
|
*/
|
||||||
explicit ostream(Stream *stream,
|
explicit ostream(Stream *stream,
|
||||||
size_t buffer_size = 1 << 10)
|
size_t buffer_size = (1 << 10))
|
||||||
: std::basic_ostream<char>(NULL), buf_(buffer_size) {
|
: std::basic_ostream<char>(NULL), buf_(buffer_size) {
|
||||||
this->set_stream(stream);
|
this->set_stream(stream);
|
||||||
}
|
}
|
||||||
@ -240,7 +242,7 @@ class ostream : public std::basic_ostream<char> {
|
|||||||
buf_.set_stream(stream);
|
buf_.set_stream(stream);
|
||||||
this->rdbuf(&buf_);
|
this->rdbuf(&buf_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// internal streambuf
|
// internal streambuf
|
||||||
class OutBuf : public std::streambuf {
|
class OutBuf : public std::streambuf {
|
||||||
@ -251,7 +253,7 @@ class ostream : public std::basic_ostream<char> {
|
|||||||
}
|
}
|
||||||
// set stream to the buffer
|
// set stream to the buffer
|
||||||
inline void set_stream(Stream *stream);
|
inline void set_stream(Stream *stream);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief internal stream by StreamBuf */
|
/*! \brief internal stream by StreamBuf */
|
||||||
Stream *stream_;
|
Stream *stream_;
|
||||||
@ -287,7 +289,7 @@ class istream : public std::basic_istream<char> {
|
|||||||
* \param buffer_size internal buffer size
|
* \param buffer_size internal buffer size
|
||||||
*/
|
*/
|
||||||
explicit istream(Stream *stream,
|
explicit istream(Stream *stream,
|
||||||
size_t buffer_size = 1 << 10)
|
size_t buffer_size = (1 << 10))
|
||||||
: std::basic_istream<char>(NULL), buf_(buffer_size) {
|
: std::basic_istream<char>(NULL), buf_(buffer_size) {
|
||||||
this->set_stream(stream);
|
this->set_stream(stream);
|
||||||
}
|
}
|
||||||
@ -325,7 +327,7 @@ class istream : public std::basic_istream<char> {
|
|||||||
Stream *stream_;
|
Stream *stream_;
|
||||||
/*! \brief how many bytes we read so far */
|
/*! \brief how many bytes we read so far */
|
||||||
size_t bytes_read_;
|
size_t bytes_read_;
|
||||||
/*! \brief internal buffer */
|
/*! \brief internal buffer */
|
||||||
std::vector<char> buffer_;
|
std::vector<char> buffer_;
|
||||||
// override underflow
|
// override underflow
|
||||||
inline int_type underflow();
|
inline int_type underflow();
|
||||||
@ -402,7 +404,7 @@ inline int ostream::OutBuf::overflow(int c) {
|
|||||||
// implementations for istream
|
// implementations for istream
|
||||||
inline void istream::InBuf::set_stream(Stream *stream) {
|
inline void istream::InBuf::set_stream(Stream *stream) {
|
||||||
stream_ = stream;
|
stream_ = stream;
|
||||||
this->setg(&buffer_[0], &buffer_[0], &buffer_[0]);
|
this->setg(&buffer_[0], &buffer_[0], &buffer_[0]);
|
||||||
}
|
}
|
||||||
inline int istream::InBuf::underflow() {
|
inline int istream::InBuf::underflow() {
|
||||||
char *bhead = &buffer_[0];
|
char *bhead = &buffer_[0];
|
||||||
|
|||||||
@ -8,12 +8,18 @@
|
|||||||
* rabit.h and serializable.h is all what the user needs to use the rabit interface
|
* rabit.h and serializable.h is all what the user needs to use the rabit interface
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_RABIT_H_
|
#ifndef RABIT_RABIT_H_ // NOLINT(*)
|
||||||
#define RABIT_RABIT_H_
|
#define RABIT_RABIT_H_ // NOLINT(*)
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// whether or not use c++11 support
|
||||||
|
#ifndef DMLC_USE_CXX11
|
||||||
|
#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
|
||||||
|
__cplusplus >= 201103L || defined(_MSC_VER))
|
||||||
|
#endif
|
||||||
// optionally support of lambda functions in C++11, if available
|
// optionally support of lambda functions in C++11, if available
|
||||||
#if __cplusplus >= 201103L
|
#if DMLC_USE_CXX11
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#endif // C++11
|
#endif // C++11
|
||||||
// contains definition of Serializable
|
// contains definition of Serializable
|
||||||
@ -56,8 +62,8 @@ struct BitOR;
|
|||||||
* \param argv the array of input arguments
|
* \param argv the array of input arguments
|
||||||
*/
|
*/
|
||||||
inline void Init(int argc, char *argv[]);
|
inline void Init(int argc, char *argv[]);
|
||||||
/*!
|
/*!
|
||||||
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
||||||
*/
|
*/
|
||||||
inline void Finalize(void);
|
inline void Finalize(void);
|
||||||
/*! \brief gets rank of the current process */
|
/*! \brief gets rank of the current process */
|
||||||
@ -71,7 +77,7 @@ inline bool IsDistributed(void);
|
|||||||
inline std::string GetProcessorName(void);
|
inline std::string GetProcessorName(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief prints the msg to the tracker,
|
* \brief prints the msg to the tracker,
|
||||||
* this function can be used to communicate progress information to
|
* this function can be used to communicate progress information to
|
||||||
* the user who monitors the tracker
|
* the user who monitors the tracker
|
||||||
* \param msg the message to be printed
|
* \param msg the message to be printed
|
||||||
*/
|
*/
|
||||||
@ -89,7 +95,7 @@ inline void TrackerPrintf(const char *fmt, ...);
|
|||||||
/*!
|
/*!
|
||||||
* \brief broadcasts a memory region to every node from the root
|
* \brief broadcasts a memory region to every node from the root
|
||||||
*
|
*
|
||||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||||
* \param size the data size
|
* \param size the data size
|
||||||
* \param root the process root
|
* \param root the process root
|
||||||
@ -113,7 +119,7 @@ inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
|
|||||||
*/
|
*/
|
||||||
inline void Broadcast(std::string *sendrecv_data, int root);
|
inline void Broadcast(std::string *sendrecv_data, int root);
|
||||||
/*!
|
/*!
|
||||||
* \brief performs in-place Allreduce on sendrecvbuf
|
* \brief performs in-place Allreduce on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
*
|
*
|
||||||
* Example Usage: the following code does an Allreduce and outputs the sum as the result
|
* Example Usage: the following code does an Allreduce and outputs the sum as the result
|
||||||
@ -126,8 +132,8 @@ inline void Broadcast(std::string *sendrecv_data, int root);
|
|||||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||||
* \tparam OP see namespace op, reduce operator
|
* \tparam OP see namespace op, reduce operator
|
||||||
* \tparam DType data type
|
* \tparam DType data type
|
||||||
*/
|
*/
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
@ -135,7 +141,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
void (*prepare_fun)(void *arg) = NULL,
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
// C++11 support for lambda prepare function
|
// C++11 support for lambda prepare function
|
||||||
#if __cplusplus >= 201103L
|
#if DMLC_USE_CXX11
|
||||||
/*!
|
/*!
|
||||||
* \brief performs in-place Allreduce, on sendrecvbuf
|
* \brief performs in-place Allreduce, on sendrecvbuf
|
||||||
* with a prepare function specified by a lambda function
|
* with a prepare function specified by a lambda function
|
||||||
@ -154,7 +160,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
|
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
|
||||||
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
* \tparam OP see namespace op, reduce operator
|
* \tparam OP see namespace op, reduce operator
|
||||||
* \tparam DType data type
|
* \tparam DType data type
|
||||||
*/
|
*/
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
@ -168,18 +174,18 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
* is the same in every node
|
* is the same in every node
|
||||||
* \param local_model pointer to the local model that is specific to the current node/rank
|
* \param local_model pointer to the local model that is specific to the current node/rank
|
||||||
* this can be NULL when no local model is needed
|
* this can be NULL when no local model is needed
|
||||||
*
|
*
|
||||||
* \return the version number of the check point loaded
|
* \return the version number of the check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, users should do the necessary initialization by themselves
|
* the p_model is not touched, users should do the necessary initialization by themselves
|
||||||
*
|
*
|
||||||
* Common usage example:
|
* Common usage example:
|
||||||
* int iter = rabit::LoadCheckPoint(&model);
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
* if (iter == 0) model.InitParameters();
|
* if (iter == 0) model.InitParameters();
|
||||||
* for (i = iter; i < max_iter; ++i) {
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
* do many things, include allreduce
|
* do many things, include allreduce
|
||||||
* rabit::CheckPoint(model);
|
* rabit::CheckPoint(model);
|
||||||
* }
|
* }
|
||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
@ -188,7 +194,7 @@ inline int LoadCheckPoint(Serializable *global_model,
|
|||||||
/*!
|
/*!
|
||||||
* \brief checkpoints the model, meaning a stage of execution has finished.
|
* \brief checkpoints the model, meaning a stage of execution has finished.
|
||||||
* every time we call check point, a version number will be increased by one
|
* every time we call check point, a version number will be increased by one
|
||||||
*
|
*
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
* when calling this function, the caller needs to guarantee that the global_model
|
* when calling this function, the caller needs to guarantee that the global_model
|
||||||
* is the same in every node
|
* is the same in every node
|
||||||
@ -204,16 +210,16 @@ inline void CheckPoint(const Serializable *global_model,
|
|||||||
/*!
|
/*!
|
||||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
* when certain condition is met (see detailed explanation).
|
* when certain condition is met (see detailed explanation).
|
||||||
*
|
*
|
||||||
* This is a "lazy" checkpoint such that only the pointer to the global_model is
|
* This is a "lazy" checkpoint such that only the pointer to the global_model is
|
||||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||||
* In other words, the global_model model can be changed only between the last call of
|
* In other words, the global_model model can be changed only between the last call of
|
||||||
* Allreduce/Broadcast and LazyCheckPoint, both in the same version
|
* Allreduce/Broadcast and LazyCheckPoint, both in the same version
|
||||||
*
|
*
|
||||||
* For example, suppose the calling sequence is:
|
* For example, suppose the calling sequence is:
|
||||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint/(or can be CheckPoint)
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint/(or can be CheckPoint)
|
||||||
*
|
*
|
||||||
* Then the user MUST only change the global_model in code3.
|
* Then the user MUST only change the global_model in code3.
|
||||||
*
|
*
|
||||||
* The use of LazyCheckPoint instead of CheckPoint will improve the efficiency of the program.
|
* The use of LazyCheckPoint instead of CheckPoint will improve the efficiency of the program.
|
||||||
@ -235,36 +241,36 @@ namespace engine {
|
|||||||
class ReduceHandle;
|
class ReduceHandle;
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
/*!
|
/*!
|
||||||
* \brief template class to make customized reduce and all reduce easy
|
* \brief template class to make customized reduce and all reduce easy
|
||||||
* Do not use reducer directly in the function you call Finalize,
|
* Do not use reducer directly in the function you call Finalize,
|
||||||
* because the destructor can execute after Finalize
|
* because the destructor can execute after Finalize
|
||||||
* \tparam DType data type that to be reduced
|
* \tparam DType data type that to be reduced
|
||||||
* \tparam freduce the customized reduction function
|
* \tparam freduce the customized reduction function
|
||||||
* DType must be a struct, with no pointer
|
* DType must be a struct, with no pointer
|
||||||
*/
|
*/
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
class Reducer {
|
class Reducer {
|
||||||
public:
|
public:
|
||||||
Reducer(void);
|
Reducer(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief customized in-place all reduce operation
|
* \brief customized in-place all reduce operation
|
||||||
* \param sendrecvbuf the in place send-recv buffer
|
* \param sendrecvbuf the in place send-recv buffer
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||||
*/
|
*/
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *arg) = NULL,
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
#if __cplusplus >= 201103L
|
#if DMLC_USE_CXX11
|
||||||
/*!
|
/*!
|
||||||
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||||
* \param sendrecvbuf pointer to the array of objects to be reduced
|
* \param sendrecvbuf pointer to the array of objects to be reduced
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||||
*/
|
*/
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
std::function<void()> prepare_fun);
|
std::function<void()> prepare_fun);
|
||||||
#endif
|
#endif
|
||||||
@ -278,7 +284,7 @@ class Reducer {
|
|||||||
* this class defines complex reducer handles all the data structure that can be
|
* this class defines complex reducer handles all the data structure that can be
|
||||||
* serialized/deserialized into fixed size buffer
|
* serialized/deserialized into fixed size buffer
|
||||||
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
|
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
|
||||||
*
|
*
|
||||||
* \tparam DType data type that to be reduced, DType must contain the following functions:
|
* \tparam DType data type that to be reduced, DType must contain the following functions:
|
||||||
* \tparam freduce the customized reduction function
|
* \tparam freduce the customized reduction function
|
||||||
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
|
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
|
||||||
@ -288,7 +294,7 @@ class SerializeReducer {
|
|||||||
public:
|
public:
|
||||||
SerializeReducer(void);
|
SerializeReducer(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief customized in-place all reduce operation
|
* \brief customized in-place all reduce operation
|
||||||
* \param sendrecvobj pointer to the array of objects to be reduced
|
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||||
* \param max_nbyte maximum amount of memory needed to serialize each object
|
* \param max_nbyte maximum amount of memory needed to serialize each object
|
||||||
* this includes budget limit for intermediate and final result
|
* this includes budget limit for intermediate and final result
|
||||||
@ -296,14 +302,14 @@ class SerializeReducer {
|
|||||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||||
* If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called
|
* If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called
|
||||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||||
*/
|
*/
|
||||||
inline void Allreduce(DType *sendrecvobj,
|
inline void Allreduce(DType *sendrecvobj,
|
||||||
size_t max_nbyte, size_t count,
|
size_t max_nbyte, size_t count,
|
||||||
void (*prepare_fun)(void *arg) = NULL,
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
// C++11 support for lambda prepare function
|
// C++11 support for lambda prepare function
|
||||||
#if __cplusplus >= 201103L
|
#if DMLC_USE_CXX11
|
||||||
/*!
|
/*!
|
||||||
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||||
* \param sendrecvobj pointer to the array of objects to be reduced
|
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||||
@ -311,7 +317,7 @@ class SerializeReducer {
|
|||||||
* this includes budget limit for intermediate and final result
|
* this includes budget limit for intermediate and final result
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||||
*/
|
*/
|
||||||
inline void Allreduce(DType *sendrecvobj,
|
inline void Allreduce(DType *sendrecvobj,
|
||||||
size_t max_nbyte, size_t count,
|
size_t max_nbyte, size_t count,
|
||||||
std::function<void()> prepare_fun);
|
std::function<void()> prepare_fun);
|
||||||
@ -326,4 +332,4 @@ class SerializeReducer {
|
|||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
// implementation of template functions
|
// implementation of template functions
|
||||||
#include "./rabit/rabit-inl.h"
|
#include "./rabit/rabit-inl.h"
|
||||||
#endif // RABIT_RABIT_H_
|
#endif // RABIT_RABIT_H_ // NOLINT(*)
|
||||||
|
|||||||
@ -4,8 +4,8 @@
|
|||||||
* \brief utilities with different serializable implementations
|
* \brief utilities with different serializable implementations
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_UTILS_IO_H_
|
#ifndef RABIT_IO_H_
|
||||||
#define RABIT_UTILS_IO_H_
|
#define RABIT_IO_H_
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
@ -51,6 +51,7 @@ struct MemoryFixSizeBuffer : public SeekStream {
|
|||||||
virtual bool AtEnd(void) const {
|
virtual bool AtEnd(void) const {
|
||||||
return curr_ptr_ == buffer_size_;
|
return curr_ptr_ == buffer_size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief in memory buffer */
|
/*! \brief in memory buffer */
|
||||||
char *p_buffer_;
|
char *p_buffer_;
|
||||||
@ -93,6 +94,7 @@ struct MemoryBufferStream : public SeekStream {
|
|||||||
virtual bool AtEnd(void) const {
|
virtual bool AtEnd(void) const {
|
||||||
return curr_ptr_ == p_buffer_->length();
|
return curr_ptr_ == p_buffer_->length();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief in memory buffer */
|
/*! \brief in memory buffer */
|
||||||
std::string *p_buffer_;
|
std::string *p_buffer_;
|
||||||
@ -101,4 +103,4 @@ struct MemoryBufferStream : public SeekStream {
|
|||||||
}; // class MemoryBufferStream
|
}; // class MemoryBufferStream
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_UTILS_IO_H_
|
#endif // RABIT_IO_H_
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
/*!
|
/*!
|
||||||
|
* Copyright by Contributors
|
||||||
* \file rabit-inl.h
|
* \file rabit-inl.h
|
||||||
* \brief implementation of inline template function for rabit interface
|
* \brief implementation of inline template function for rabit interface
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_RABIT_INL_H
|
#ifndef RABIT_RABIT_INL_H_
|
||||||
#define RABIT_RABIT_INL_H
|
#define RABIT_RABIT_INL_H_
|
||||||
// use engine for implementation
|
// use engine for implementation
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
#include "../rabit.h"
|
#include "../rabit.h"
|
||||||
@ -30,15 +33,15 @@ inline DataType GetType<int>(void) {
|
|||||||
return kInt;
|
return kInt;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned int>(void) {
|
inline DataType GetType<unsigned int>(void) { // NOLINT(*)
|
||||||
return kUInt;
|
return kUInt;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<long>(void) {
|
inline DataType GetType<long>(void) { // NOLINT(*)
|
||||||
return kLong;
|
return kLong;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned long>(void) {
|
inline DataType GetType<unsigned long>(void) { // NOLINT(*)
|
||||||
return kULong;
|
return kULong;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
@ -50,54 +53,54 @@ inline DataType GetType<double>(void) {
|
|||||||
return kDouble;
|
return kDouble;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<long long>(void) {
|
inline DataType GetType<long long>(void) { // NOLINT(*)
|
||||||
return kLongLong;
|
return kLongLong;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned long long>(void) {
|
inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
|
||||||
return kULongLong;
|
return kULongLong;
|
||||||
}
|
}
|
||||||
} // namespace mpi
|
} // namespace mpi
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
|
|
||||||
namespace op {
|
namespace op {
|
||||||
struct Max {
|
struct Max {
|
||||||
const static engine::mpi::OpType kType = engine::mpi::kMax;
|
static const engine::mpi::OpType kType = engine::mpi::kMax;
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||||
if (dst < src) dst = src;
|
if (dst < src) dst = src;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Min {
|
struct Min {
|
||||||
const static engine::mpi::OpType kType = engine::mpi::kMin;
|
static const engine::mpi::OpType kType = engine::mpi::kMin;
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||||
if (dst > src) dst = src;
|
if (dst > src) dst = src;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Sum {
|
struct Sum {
|
||||||
const static engine::mpi::OpType kType = engine::mpi::kSum;
|
static const engine::mpi::OpType kType = engine::mpi::kSum;
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||||
dst += src;
|
dst += src;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct BitOR {
|
struct BitOR {
|
||||||
const static engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline static void Reduce(DType &dst, const DType &src) {
|
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||||
dst |= src;
|
dst |= src;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||||
const DType *src = (const DType*)src_;
|
const DType *src = (const DType*)src_;
|
||||||
DType *dst = (DType*)dst_;
|
DType *dst = (DType*)dst_; // NOLINT(*)
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
OP::Reduce(dst[i], src[i]);
|
OP::Reduce(dst[i], src[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
// intialize the rabit engine
|
// intialize the rabit engine
|
||||||
inline void Init(int argc, char *argv[]) {
|
inline void Init(int argc, char *argv[]) {
|
||||||
@ -152,23 +155,23 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
|
|||||||
// perform inplace Allreduce
|
// perform inplace Allreduce
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||||
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
|
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// C++11 support for lambda prepare function
|
// C++11 support for lambda prepare function
|
||||||
#if __cplusplus >= 201103L
|
#if DMLC_USE_CXX11
|
||||||
inline void InvokeLambda_(void *fun) {
|
inline void InvokeLambda_(void *fun) {
|
||||||
(*static_cast<std::function<void()>*>(fun))();
|
(*static_cast<std::function<void()>*>(fun))();
|
||||||
}
|
}
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
|
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun) {
|
||||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>,
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
|
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun);
|
||||||
}
|
}
|
||||||
#endif // C++11
|
#endif // C++11
|
||||||
|
|
||||||
// print message to the tracker
|
// print message to the tracker
|
||||||
inline void TrackerPrint(const std::string &msg) {
|
inline void TrackerPrint(const std::string &msg) {
|
||||||
@ -223,15 +226,16 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// function to perform reduction for Reducer
|
// function to perform reduction for Reducer
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
inline void ReducerAlign_(const void *src_, void *dst_,
|
||||||
|
int len_, const MPI::Datatype &dtype) {
|
||||||
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
||||||
DType *pdst = reinterpret_cast<DType*>(dst_);
|
DType *pdst = reinterpret_cast<DType*>(dst_);
|
||||||
for (int i = 0; i < len_; ++i) {
|
for (int i = 0; i < len_; ++i) {
|
||||||
freduce(pdst[i], psrc[i]);
|
freduce(pdst[i], psrc[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
inline Reducer<DType, freduce>::Reducer(void) {
|
inline Reducer<DType, freduce>::Reducer(void) {
|
||||||
// it is safe to directly use handle for aligned data types
|
// it is safe to directly use handle for aligned data types
|
||||||
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
||||||
@ -240,7 +244,7 @@ inline Reducer<DType, freduce>::Reducer(void) {
|
|||||||
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
@ -248,13 +252,14 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
}
|
}
|
||||||
// function to perform reduction for SerializeReducer
|
// function to perform reduction for SerializeReducer
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void SerializeReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
inline void SerializeReducerFunc_(const void *src_, void *dst_,
|
||||||
|
int len_, const MPI::Datatype &dtype) {
|
||||||
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||||
// temp space
|
// temp space
|
||||||
DType tsrc, tdst;
|
DType tsrc, tdst;
|
||||||
for (int i = 0; i < len_; ++i) {
|
for (int i = 0; i < len_; ++i) {
|
||||||
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
|
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); // NOLINT(*)
|
||||||
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
|
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); // NOLINT(*)
|
||||||
tsrc.Load(fsrc);
|
tsrc.Load(fsrc);
|
||||||
tdst.Load(fdst);
|
tdst.Load(fdst);
|
||||||
// govern const check
|
// govern const check
|
||||||
@ -296,8 +301,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|||||||
// setup closure
|
// setup closure
|
||||||
SerializeReduceClosure<DType> c;
|
SerializeReduceClosure<DType> c;
|
||||||
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
|
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
|
||||||
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
||||||
// invoke here
|
// invoke here
|
||||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
||||||
SerializeReduceClosure<DType>::Invoke, &c);
|
SerializeReduceClosure<DType>::Invoke, &c);
|
||||||
for (size_t i = 0; i < count; ++i) {
|
for (size_t i = 0; i < count; ++i) {
|
||||||
@ -306,8 +311,8 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if __cplusplus >= 201103L
|
#if DMLC_USE_CXX11
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
|
||||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
std::function<void()> prepare_fun) {
|
std::function<void()> prepare_fun) {
|
||||||
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
||||||
@ -320,4 +325,4 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif
|
#endif // RABIT_RABIT_INL_H_
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
|
* Copyright by Contributors
|
||||||
* \file timer.h
|
* \file timer.h
|
||||||
* \brief This file defines the utils for timing
|
* \brief This file defines the utils for timing
|
||||||
* \author Tianqi Chen, Nacho, Tianyi
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
@ -18,7 +19,6 @@ namespace utils {
|
|||||||
* \brief return time in seconds, not cross platform, avoid to use this in most places
|
* \brief return time in seconds, not cross platform, avoid to use this in most places
|
||||||
*/
|
*/
|
||||||
inline double GetTime(void) {
|
inline double GetTime(void) {
|
||||||
// TODO: use c++11 chrono when c++11 was available
|
|
||||||
#ifdef __MACH__
|
#ifdef __MACH__
|
||||||
clock_serv_t cclock;
|
clock_serv_t cclock;
|
||||||
mach_timespec_t mts;
|
mach_timespec_t mts;
|
||||||
@ -32,7 +32,6 @@ inline double GetTime(void) {
|
|||||||
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
|
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
|
||||||
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
|
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
|
||||||
#else
|
#else
|
||||||
// TODO: add MSVC macro, and MSVC timer
|
|
||||||
return static_cast<double>(time(NULL));
|
return static_cast<double>(time(NULL));
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -27,7 +27,7 @@
|
|||||||
#else
|
#else
|
||||||
#ifdef _FILE_OFFSET_BITS
|
#ifdef _FILE_OFFSET_BITS
|
||||||
#if _FILE_OFFSET_BITS == 32
|
#if _FILE_OFFSET_BITS == 32
|
||||||
#pragma message ("Warning: FILE OFFSET BITS defined to be 32 bit")
|
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -59,17 +59,17 @@ namespace utils {
|
|||||||
const int kPrintBuffer = 1 << 12;
|
const int kPrintBuffer = 1 << 12;
|
||||||
|
|
||||||
#ifndef RABIT_CUSTOMIZE_MSG_
|
#ifndef RABIT_CUSTOMIZE_MSG_
|
||||||
/*!
|
/*!
|
||||||
* \brief handling of Assert error, caused by inappropriate input
|
* \brief handling of Assert error, caused by inappropriate input
|
||||||
* \param msg error message
|
* \param msg error message
|
||||||
*/
|
*/
|
||||||
inline void HandleAssertError(const char *msg) {
|
inline void HandleAssertError(const char *msg) {
|
||||||
fprintf(stderr, "AssertError:%s\n", msg);
|
fprintf(stderr, "AssertError:%s\n", msg);
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief handling of Check error, caused by inappropriate input
|
* \brief handling of Check error, caused by inappropriate input
|
||||||
* \param msg error message
|
* \param msg error message
|
||||||
*/
|
*/
|
||||||
inline void HandleCheckError(const char *msg) {
|
inline void HandleCheckError(const char *msg) {
|
||||||
fprintf(stderr, "%s\n", msg);
|
fprintf(stderr, "%s\n", msg);
|
||||||
@ -163,7 +163,7 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
|||||||
// easy utils that can be directly accessed in xgboost
|
// easy utils that can be directly accessed in xgboost
|
||||||
/*! \brief get the beginning address of a vector */
|
/*! \brief get the beginning address of a vector */
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline T *BeginPtr(std::vector<T> &vec) {
|
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
|
||||||
if (vec.size() == 0) {
|
if (vec.size() == 0) {
|
||||||
return NULL;
|
return NULL;
|
||||||
} else {
|
} else {
|
||||||
@ -172,14 +172,14 @@ inline T *BeginPtr(std::vector<T> &vec) {
|
|||||||
}
|
}
|
||||||
/*! \brief get the beginning address of a vector */
|
/*! \brief get the beginning address of a vector */
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline const T *BeginPtr(const std::vector<T> &vec) {
|
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
|
||||||
if (vec.size() == 0) {
|
if (vec.size() == 0) {
|
||||||
return NULL;
|
return NULL;
|
||||||
} else {
|
} else {
|
||||||
return &vec[0];
|
return &vec[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline char* BeginPtr(std::string &str) {
|
inline char* BeginPtr(std::string &str) { // NOLINT(*)
|
||||||
if (str.length() == 0) return NULL;
|
if (str.length() == 0) return NULL;
|
||||||
return &str[0];
|
return &str[0];
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,8 +4,8 @@
|
|||||||
* \brief defines serializable interface of rabit
|
* \brief defines serializable interface of rabit
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_RABIT_SERIALIZABLE_H_
|
#ifndef RABIT_SERIALIZABLE_H_
|
||||||
#define RABIT_RABIT_SERIALIZABLE_H_
|
#define RABIT_SERIALIZABLE_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "./rabit/utils.h"
|
#include "./rabit/utils.h"
|
||||||
@ -13,15 +13,15 @@
|
|||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
/*!
|
/*!
|
||||||
* \brief defines stream used in rabit
|
* \brief defines stream used in rabit
|
||||||
* see definition of Stream in dmlc/io.h
|
* see definition of Stream in dmlc/io.h
|
||||||
*/
|
*/
|
||||||
typedef dmlc::Stream Stream;
|
typedef dmlc::Stream Stream;
|
||||||
/*!
|
/*!
|
||||||
* \brief defines serializable objects used in rabit
|
* \brief defines serializable objects used in rabit
|
||||||
* see definition of Serializable in dmlc/io.h
|
* see definition of Serializable in dmlc/io.h
|
||||||
*/
|
*/
|
||||||
typedef dmlc::Serializable Serializable;
|
typedef dmlc::Serializable Serializable;
|
||||||
|
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_RABIT_SERIALIZABLE_H_
|
#endif // RABIT_SERIALIZABLE_H_
|
||||||
|
|||||||
8
scripts/travis_runtest.sh
Executable file
8
scripts/travis_runtest.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
make -f test.mk model_recover_10_10k || exit -1
|
||||||
|
make -f test.mk model_recover_10_10k_die_same || exit -1
|
||||||
|
make -f test.mk local_recover_10_10k || exit -1
|
||||||
|
make -f test.mk pylocal_recover_10_10k || exit -1
|
||||||
|
make -f test.mk lazy_recover_10_10k_die_hard || exit -1
|
||||||
|
make -f test.mk lazy_recover_10_10k_die_same || exit -1
|
||||||
|
make -f test.mk ringallreduce_10_10k || exit -1
|
||||||
22
scripts/travis_script.sh
Executable file
22
scripts/travis_script.sh
Executable file
@ -0,0 +1,22 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# main script of travis
|
||||||
|
if [ ${TASK} == "lint" ]; then
|
||||||
|
make lint || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${TASK} == "doc" ]; then
|
||||||
|
make doc 2>log.txt
|
||||||
|
(cat log.txt|grep warning) && exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${TASK} == "build" ]; then
|
||||||
|
make all || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${TASK} == "test" ]; then
|
||||||
|
cd test
|
||||||
|
make all || exit -1
|
||||||
|
./travis_runtest.sh || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
@ -94,7 +94,8 @@ void AllreduceBase::Init(void) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (dmlc_role != "worker") {
|
if (dmlc_role != "worker") {
|
||||||
fprintf(stderr, "Rabit Module currently only work with dmlc worker, quit this program by exit 0\n");
|
fprintf(stderr, "Rabit Module currently only work with dmlc worker"\
|
||||||
|
", quit this program by exit 0\n");
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
// clear the setting before start reconnection
|
// clear the setting before start reconnection
|
||||||
@ -134,7 +135,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
|||||||
// util to parse data with unit suffix
|
// util to parse data with unit suffix
|
||||||
inline size_t ParseUnit(const char *name, const char *val) {
|
inline size_t ParseUnit(const char *name, const char *val) {
|
||||||
char unit;
|
char unit;
|
||||||
unsigned long amt;
|
unsigned long amt; // NOLINT(*)
|
||||||
int n = sscanf(val, "%lu%c", &amt, &unit);
|
int n = sscanf(val, "%lu%c", &amt, &unit);
|
||||||
size_t amount = amt;
|
size_t amount = amt;
|
||||||
if (n == 2) {
|
if (n == 2) {
|
||||||
@ -154,7 +155,7 @@ inline size_t ParseUnit(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
@ -258,7 +259,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
} else {
|
} else {
|
||||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int ngood = static_cast<int>(good_link.size());
|
int ngood = static_cast<int>(good_link.size());
|
||||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||||
"ReConnectLink failure 5");
|
"ReConnectLink failure 5");
|
||||||
@ -359,7 +360,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
||||||
* It only means the current node get the correct result of Allreduce.
|
* It only means the current node get the correct result of Allreduce.
|
||||||
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||||
*
|
*
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
@ -440,7 +441,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
selecter.WatchRead(links[i].sock);
|
selecter.WatchRead(links[i].sock);
|
||||||
}
|
}
|
||||||
// size_write <= size_read
|
// size_write <= size_read
|
||||||
if (links[i].size_write != total_size){
|
if (links[i].size_write != total_size) {
|
||||||
if (links[i].size_write < size_down_in) {
|
if (links[i].size_write < size_down_in) {
|
||||||
selecter.WatchWrite(links[i].sock);
|
selecter.WatchWrite(links[i].sock);
|
||||||
}
|
}
|
||||||
@ -477,7 +478,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
size_t max_reduce = total_size;
|
size_t max_reduce = total_size;
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != parent_index) {
|
if (i != parent_index) {
|
||||||
max_reduce= std::min(max_reduce, links[i].size_read);
|
max_reduce = std::min(max_reduce, links[i].size_read);
|
||||||
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
|
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
|
||||||
"buffer size inconsistent");
|
"buffer size inconsistent");
|
||||||
buffer_size = links[i].buffer_size;
|
buffer_size = links[i].buffer_size;
|
||||||
@ -525,7 +526,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
ssize_t len = links[parent_index].sock.
|
ssize_t len = links[parent_index].sock.
|
||||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
links[parent_index].sock.Close();
|
links[parent_index].sock.Close();
|
||||||
return ReportError(&links[parent_index], kRecvZeroLen);
|
return ReportError(&links[parent_index], kRecvZeroLen);
|
||||||
}
|
}
|
||||||
if (len != -1) {
|
if (len != -1) {
|
||||||
@ -670,7 +671,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
|||||||
size_t slice_begin,
|
size_t slice_begin,
|
||||||
size_t slice_end,
|
size_t slice_end,
|
||||||
size_t size_prev_slice) {
|
size_t size_prev_slice) {
|
||||||
// read from next link and send to prev one
|
// read from next link and send to prev one
|
||||||
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||||
// need to reply on special rank structure
|
// need to reply on special rank structure
|
||||||
utils::Assert(next.rank == (rank + 1) % world_size &&
|
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||||
@ -678,11 +679,11 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
|||||||
"need to assume rank structure");
|
"need to assume rank structure");
|
||||||
// send recv buffer
|
// send recv buffer
|
||||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||||
const size_t stop_read = total_size + slice_begin;
|
const size_t stop_read = total_size + slice_begin;
|
||||||
const size_t stop_write = total_size + slice_begin - size_prev_slice;
|
const size_t stop_write = total_size + slice_begin - size_prev_slice;
|
||||||
size_t write_ptr = slice_begin;
|
size_t write_ptr = slice_begin;
|
||||||
size_t read_ptr = slice_end;
|
size_t read_ptr = slice_end;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
// select helper
|
// select helper
|
||||||
bool finished = true;
|
bool finished = true;
|
||||||
@ -733,7 +734,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
|||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
|
||||||
* and will return the cause of failure
|
* and will return the cause of failure
|
||||||
*
|
*
|
||||||
* Ring-based algorithm
|
* Ring-based algorithm
|
||||||
*
|
*
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
@ -748,7 +749,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
// read from next link and send to prev one
|
// read from next link and send to prev one
|
||||||
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||||
// need to reply on special rank structure
|
// need to reply on special rank structure
|
||||||
utils::Assert(next.rank == (rank + 1) % world_size &&
|
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||||
@ -757,7 +758,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
|||||||
// total size of message
|
// total size of message
|
||||||
const size_t total_size = type_nbytes * count;
|
const size_t total_size = type_nbytes * count;
|
||||||
size_t n = static_cast<size_t>(world_size);
|
size_t n = static_cast<size_t>(world_size);
|
||||||
size_t step = (count + n - 1) / n;
|
size_t step = (count + n - 1) / n;
|
||||||
size_t r = static_cast<size_t>(next.rank);
|
size_t r = static_cast<size_t>(next.rank);
|
||||||
size_t write_ptr = std::min(r * step, count) * type_nbytes;
|
size_t write_ptr = std::min(r * step, count) * type_nbytes;
|
||||||
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
|
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
|
||||||
@ -830,7 +831,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
|||||||
if (ret != kSuccess) return ReportError(&prev, ret);
|
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -857,7 +858,7 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
|
|||||||
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
|
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
|
||||||
// previous rank
|
// previous rank
|
||||||
int prank = ring_prev->rank;
|
int prank = ring_prev->rank;
|
||||||
// get rank of previous
|
// get rank of previous
|
||||||
return TryAllgatherRing
|
return TryAllgatherRing
|
||||||
(sendrecvbuf_, type_nbytes * count,
|
(sendrecvbuf_, type_nbytes * count,
|
||||||
begin, end,
|
begin, end,
|
||||||
|
|||||||
@ -42,7 +42,7 @@ class AllreduceBase : public IEngine {
|
|||||||
// shutdown the engine
|
// shutdown the engine
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
@ -72,7 +72,7 @@ class AllreduceBase : public IEngine {
|
|||||||
return host_uri;
|
return host_uri;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
@ -82,7 +82,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||||
*/
|
*/
|
||||||
virtual void Allreduce(void *sendrecvbuf_,
|
virtual void Allreduce(void *sendrecvbuf_,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
@ -117,14 +117,14 @@ class AllreduceBase : public IEngine {
|
|||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
*
|
*
|
||||||
* Common usage example:
|
* Common usage example:
|
||||||
* int iter = rabit::LoadCheckPoint(&model);
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
* if (iter == 0) model.InitParameters();
|
* if (iter == 0) model.InitParameters();
|
||||||
* for (i = iter; i < max_iter; ++i) {
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
* do many things, include allreduce
|
* do many things, include allreduce
|
||||||
* rabit::CheckPoint(model);
|
* rabit::CheckPoint(model);
|
||||||
* }
|
* }
|
||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
@ -135,7 +135,7 @@ class AllreduceBase : public IEngine {
|
|||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
* when calling this function, the caller need to gauranttees that global_model
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
@ -155,16 +155,16 @@ class AllreduceBase : public IEngine {
|
|||||||
/*!
|
/*!
|
||||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
* when certain condition is met(see detailed expplaination).
|
* when certain condition is met(see detailed expplaination).
|
||||||
*
|
*
|
||||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
||||||
* In another words, global_model model can be changed only between last call of
|
* In another words, global_model model can be changed only between last call of
|
||||||
* Allreduce/Broadcast and LazyCheckPoint in current version
|
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||||
*
|
*
|
||||||
* For example, suppose the calling sequence is:
|
* For example, suppose the calling sequence is:
|
||||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||||
*
|
*
|
||||||
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
||||||
* improve efficiency of the program.
|
* improve efficiency of the program.
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
@ -191,8 +191,8 @@ class AllreduceBase : public IEngine {
|
|||||||
virtual void InitAfterException(void) {
|
virtual void InitAfterException(void) {
|
||||||
utils::Error("InitAfterException: not implemented");
|
utils::Error("InitAfterException: not implemented");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief report current status to the job tracker
|
* \brief report current status to the job tracker
|
||||||
* depending on the job tracker we are in
|
* depending on the job tracker we are in
|
||||||
*/
|
*/
|
||||||
inline void ReportStatus(void) const {
|
inline void ReportStatus(void) const {
|
||||||
@ -213,7 +213,7 @@ class AllreduceBase : public IEngine {
|
|||||||
kRecvZeroLen,
|
kRecvZeroLen,
|
||||||
/*! \brief a neighbor node go down, the connection is dropped */
|
/*! \brief a neighbor node go down, the connection is dropped */
|
||||||
kSockError,
|
kSockError,
|
||||||
/*!
|
/*!
|
||||||
* \brief another node which is not my neighbor go down,
|
* \brief another node which is not my neighbor go down,
|
||||||
* get Out-of-Band exception notification from my neighbor
|
* get Out-of-Band exception notification from my neighbor
|
||||||
*/
|
*/
|
||||||
@ -225,7 +225,7 @@ class AllreduceBase : public IEngine {
|
|||||||
ReturnTypeEnum value;
|
ReturnTypeEnum value;
|
||||||
// constructor
|
// constructor
|
||||||
ReturnType() {}
|
ReturnType() {}
|
||||||
ReturnType(ReturnTypeEnum value) : value(value){}
|
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
|
||||||
inline bool operator==(const ReturnTypeEnum &v) const {
|
inline bool operator==(const ReturnTypeEnum &v) const {
|
||||||
return value == v;
|
return value == v;
|
||||||
}
|
}
|
||||||
@ -235,11 +235,11 @@ class AllreduceBase : public IEngine {
|
|||||||
};
|
};
|
||||||
/*! \brief translate errno to return type */
|
/*! \brief translate errno to return type */
|
||||||
inline static ReturnType Errno2Return() {
|
inline static ReturnType Errno2Return() {
|
||||||
int errsv = utils::Socket::GetLastError();
|
int errsv = utils::Socket::GetLastError();
|
||||||
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
|
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
if (errsv == WSAEWOULDBLOCK) return kSuccess;
|
if (errsv == WSAEWOULDBLOCK) return kSuccess;
|
||||||
if (errsv == WSAECONNRESET) return kConnReset;
|
if (errsv == WSAECONNRESET) return kConnReset;
|
||||||
#endif
|
#endif
|
||||||
if (errsv == ECONNRESET) return kConnReset;
|
if (errsv == ECONNRESET) return kConnReset;
|
||||||
return kSockError;
|
return kSockError;
|
||||||
@ -260,7 +260,7 @@ class AllreduceBase : public IEngine {
|
|||||||
// buffer size, in bytes
|
// buffer size, in bytes
|
||||||
size_t buffer_size;
|
size_t buffer_size;
|
||||||
// constructor
|
// constructor
|
||||||
LinkRecord(void)
|
LinkRecord(void)
|
||||||
: buffer_head(NULL), buffer_size(0) {
|
: buffer_head(NULL), buffer_size(0) {
|
||||||
}
|
}
|
||||||
// initialize buffer
|
// initialize buffer
|
||||||
@ -377,7 +377,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
||||||
* It only means the current node get the correct result of Allreduce.
|
* It only means the current node get the correct result of Allreduce.
|
||||||
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||||
*
|
*
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
@ -397,7 +397,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf,
|
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||||
* this function implements tree-shape reduction
|
* this function implements tree-shape reduction
|
||||||
@ -433,14 +433,14 @@ class AllreduceBase : public IEngine {
|
|||||||
size_t size_prev_slice);
|
size_t size_prev_slice);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
|
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
|
||||||
*
|
*
|
||||||
* after the function, node k get k-th segment of the reduction result
|
* after the function, node k get k-th segment of the reduction result
|
||||||
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
|
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
|
||||||
* where step = ceil(count / world_size)
|
* where step = ceil(count / world_size)
|
||||||
*
|
*
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
* \param count number of elements to be reduced
|
* \param count number of elements to be reduced
|
||||||
* \param reducer reduce function
|
* \param reducer reduce function
|
||||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType, TryAllreduce
|
* \sa ReturnType, TryAllreduce
|
||||||
@ -465,7 +465,7 @@ class AllreduceBase : public IEngine {
|
|||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer);
|
ReduceFunction reducer);
|
||||||
/*!
|
/*!
|
||||||
* \brief function used to report error when a link goes wrong
|
* \brief function used to report error when a link goes wrong
|
||||||
* \param link the pointer to the link who causes the error
|
* \param link the pointer to the link who causes the error
|
||||||
* \param err the error type
|
* \param err the error type
|
||||||
*/
|
*/
|
||||||
@ -522,4 +522,4 @@ class AllreduceBase : public IEngine {
|
|||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_ALLREDUCE_BASE_H
|
#endif // RABIT_ALLREDUCE_BASE_H_
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
/*!
|
/*!
|
||||||
|
* Copyright by Contributors
|
||||||
* \file allreduce_mock.h
|
* \file allreduce_mock.h
|
||||||
* \brief Mock test module of AllReduce engine,
|
* \brief Mock test module of AllReduce engine,
|
||||||
* insert failures in certain call point, to test if the engine is robust to failure
|
* insert failures in certain call point, to test if the engine is robust to failure
|
||||||
*
|
*
|
||||||
* \author Ignacio Cano, Tianqi Chen
|
* \author Ignacio Cano, Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_ALLREDUCE_MOCK_H_
|
#ifndef RABIT_ALLREDUCE_MOCK_H_
|
||||||
@ -68,7 +69,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
DummySerializer dum;
|
DummySerializer dum;
|
||||||
ComboSerializer com(global_model, local_model);
|
ComboSerializer com(global_model, local_model);
|
||||||
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const Serializable *global_model,
|
virtual void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model) {
|
const Serializable *local_model) {
|
||||||
@ -100,6 +101,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
||||||
AllreduceRobust::LazyCheckPoint(global_model);
|
AllreduceRobust::LazyCheckPoint(global_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// force checkpoint to local
|
// force checkpoint to local
|
||||||
int force_local;
|
int force_local;
|
||||||
@ -108,7 +110,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
// sum of allreduce
|
// sum of allreduce
|
||||||
double tsum_allreduce;
|
double tsum_allreduce;
|
||||||
double time_checkpoint;
|
double time_checkpoint;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct DummySerializer : public Serializable {
|
struct DummySerializer : public Serializable {
|
||||||
virtual void Load(Stream *fi) {
|
virtual void Load(Stream *fi) {
|
||||||
@ -126,7 +128,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
}
|
}
|
||||||
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
||||||
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
||||||
}
|
}
|
||||||
virtual void Load(Stream *fi) {
|
virtual void Load(Stream *fi) {
|
||||||
if (lhs != NULL) lhs->Load(fi);
|
if (lhs != NULL) lhs->Load(fi);
|
||||||
if (rhs != NULL) rhs->Load(fi);
|
if (rhs != NULL) rhs->Load(fi);
|
||||||
@ -143,10 +145,10 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
int seqno;
|
int seqno;
|
||||||
int ntrial;
|
int ntrial;
|
||||||
MockKey(void) {}
|
MockKey(void) {}
|
||||||
MockKey(int rank, int version, int seqno, int ntrial)
|
MockKey(int rank, int version, int seqno, int ntrial)
|
||||||
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
||||||
inline bool operator==(const MockKey &b) const {
|
inline bool operator==(const MockKey &b) const {
|
||||||
return rank == b.rank &&
|
return rank == b.rank &&
|
||||||
version == b.version &&
|
version == b.version &&
|
||||||
seqno == b.seqno &&
|
seqno == b.seqno &&
|
||||||
ntrial == b.ntrial;
|
ntrial == b.ntrial;
|
||||||
@ -173,4 +175,4 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_ALLREDUCE_MOCK_H_
|
#endif // RABIT_ALLREDUCE_MOCK_H_
|
||||||
|
|||||||
@ -2,17 +2,17 @@
|
|||||||
* Copyright (c) 2014 by Contributors
|
* Copyright (c) 2014 by Contributors
|
||||||
* \file allreduce_robust-inl.h
|
* \file allreduce_robust-inl.h
|
||||||
* \brief implementation of inline template function in AllreduceRobust
|
* \brief implementation of inline template function in AllreduceRobust
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_ENGINE_ROBUST_INL_H_
|
#ifndef RABIT_ALLREDUCE_ROBUST_INL_H_
|
||||||
#define RABIT_ENGINE_ROBUST_INL_H_
|
#define RABIT_ALLREDUCE_ROBUST_INL_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
/*!
|
/*!
|
||||||
* \brief run message passing algorithm on the allreduce tree
|
* \brief run message passing algorithm on the allreduce tree
|
||||||
* the result is edge message stored in p_edge_in and p_edge_out
|
* the result is edge message stored in p_edge_in and p_edge_out
|
||||||
* \param node_value the value associated with current node
|
* \param node_value the value associated with current node
|
||||||
* \param p_edge_in used to store input message from each of the edge
|
* \param p_edge_in used to store input message from each of the edge
|
||||||
@ -35,7 +35,7 @@ inline AllreduceRobust::ReturnType
|
|||||||
AllreduceRobust::MsgPassing(const NodeType &node_value,
|
AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||||
std::vector<EdgeType> *p_edge_in,
|
std::vector<EdgeType> *p_edge_in,
|
||||||
std::vector<EdgeType> *p_edge_out,
|
std::vector<EdgeType> *p_edge_out,
|
||||||
EdgeType (*func)
|
EdgeType(*func)
|
||||||
(const NodeType &node_value,
|
(const NodeType &node_value,
|
||||||
const std::vector<EdgeType> &edge_in,
|
const std::vector<EdgeType> &edge_in,
|
||||||
size_t out_index)) {
|
size_t out_index)) {
|
||||||
@ -80,8 +80,16 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
selecter.WatchRead(links[i].sock);
|
selecter.WatchRead(links[i].sock);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case 1: if (i == parent_index) selecter.WatchWrite(links[i].sock); break;
|
case 1:
|
||||||
case 2: if (i == parent_index) selecter.WatchRead(links[i].sock); break;
|
if (i == parent_index) {
|
||||||
|
selecter.WatchWrite(links[i].sock);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
if (i == parent_index) {
|
||||||
|
selecter.WatchRead(links[i].sock);
|
||||||
|
}
|
||||||
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||||
selecter.WatchWrite(links[i].sock);
|
selecter.WatchWrite(links[i].sock);
|
||||||
@ -158,4 +166,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
}
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_ENGINE_ROBUST_INL_H_
|
#endif // RABIT_ALLREDUCE_ROBUST_INL_H_
|
||||||
|
|||||||
@ -27,7 +27,7 @@ AllreduceRobust::AllreduceRobust(void) {
|
|||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
global_lazycheck = NULL;
|
global_lazycheck = NULL;
|
||||||
use_local_model = -1;
|
use_local_model = -1;
|
||||||
recover_counter = 0;
|
recover_counter = 0;
|
||||||
env_vars.push_back("rabit_global_replica");
|
env_vars.push_back("rabit_global_replica");
|
||||||
env_vars.push_back("rabit_local_replica");
|
env_vars.push_back("rabit_local_replica");
|
||||||
}
|
}
|
||||||
@ -49,7 +49,7 @@ void AllreduceRobust::Shutdown(void) {
|
|||||||
AllreduceBase::Shutdown();
|
AllreduceBase::Shutdown();
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
@ -61,7 +61,7 @@ void AllreduceRobust::SetParam(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
@ -147,14 +147,14 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
*
|
*
|
||||||
* Common usage example:
|
* Common usage example:
|
||||||
* int iter = rabit::LoadCheckPoint(&model);
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
* if (iter == 0) model.InitParameters();
|
* if (iter == 0) model.InitParameters();
|
||||||
* for (i = iter; i < max_iter; ++i) {
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
* do many things, include allreduce
|
* do many things, include allreduce
|
||||||
* rabit::CheckPoint(model);
|
* rabit::CheckPoint(model);
|
||||||
* }
|
* }
|
||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
@ -208,7 +208,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
|
|||||||
* \brief internal consistency check function,
|
* \brief internal consistency check function,
|
||||||
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
||||||
* with or without local but not both, this function will set the approperiate settings
|
* with or without local but not both, this function will set the approperiate settings
|
||||||
* in the first call of LoadCheckPoint/CheckPoint
|
* in the first call of LoadCheckPoint/CheckPoint
|
||||||
*
|
*
|
||||||
* \param with_local whether the user calls CheckPoint with local model
|
* \param with_local whether the user calls CheckPoint with local model
|
||||||
*/
|
*/
|
||||||
@ -224,14 +224,14 @@ void AllreduceRobust::LocalModelCheck(bool with_local) {
|
|||||||
num_local_replica = 0;
|
num_local_replica = 0;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
utils::Check(use_local_model == int(with_local),
|
utils::Check(use_local_model == static_cast<int>(with_local),
|
||||||
"Can only call Checkpoint/LoadCheckPoint always with"\
|
"Can only call Checkpoint/LoadCheckPoint always with"\
|
||||||
"or without local_model, but not mixed case");
|
"or without local_model, but not mixed case");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief internal implementation of checkpoint, support both lazy and normal way
|
* \brief internal implementation of checkpoint, support both lazy and normal way
|
||||||
*
|
*
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
* when calling this function, the caller need to gauranttees that global_model
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
@ -423,7 +423,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
* recover links according to the error type reported
|
* recover links according to the error type reported
|
||||||
* if there is no error, return true
|
* if there is no error, return true
|
||||||
* \param err_type the type of error happening in the system
|
* \param err_type the type of error happening in the system
|
||||||
* \return true if err_type is kSuccess, false otherwise
|
* \return true if err_type is kSuccess, false otherwise
|
||||||
*/
|
*/
|
||||||
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||||
if (err_type == kSuccess) return true;
|
if (err_type == kSuccess) return true;
|
||||||
@ -488,7 +488,7 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
|
|||||||
* \brief message passing function, used to decide the
|
* \brief message passing function, used to decide the
|
||||||
* data request from each edge, whether need to request data from certain edge
|
* data request from each edge, whether need to request data from certain edge
|
||||||
* \param node_value a pair of request_data and best_link
|
* \param node_value a pair of request_data and best_link
|
||||||
* request_data stores whether current node need to request data
|
* request_data stores whether current node need to request data
|
||||||
* best_link gives the best edge index to fetch the data
|
* best_link gives the best edge index to fetch the data
|
||||||
* \param req_in the data request from incoming edges
|
* \param req_in the data request from incoming edges
|
||||||
* \param out_index the edge index of output link
|
* \param out_index the edge index of output link
|
||||||
@ -524,7 +524,7 @@ inline char DataRequest(const std::pair<bool, int> &node_value,
|
|||||||
*
|
*
|
||||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllreduceRobust::ReturnType
|
AllreduceRobust::ReturnType
|
||||||
AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
||||||
size_t *p_size,
|
size_t *p_size,
|
||||||
@ -586,7 +586,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
|||||||
*
|
*
|
||||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
* \sa ReturnType, TryDecideRouting
|
* \sa ReturnType, TryDecideRouting
|
||||||
*/
|
*/
|
||||||
AllreduceRobust::ReturnType
|
AllreduceRobust::ReturnType
|
||||||
AllreduceRobust::TryRecoverData(RecoverType role,
|
AllreduceRobust::TryRecoverData(RecoverType role,
|
||||||
void *sendrecvbuf_,
|
void *sendrecvbuf_,
|
||||||
@ -644,7 +644,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
if (role == kRequestData) {
|
if (role == kRequestData) {
|
||||||
const int pid = recv_link;
|
const int pid = recv_link;
|
||||||
if (selecter.CheckRead(links[pid].sock)) {
|
if (selecter.CheckRead(links[pid].sock)) {
|
||||||
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
|
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
return ReportError(&links[pid], ret);
|
return ReportError(&links[pid], ret);
|
||||||
}
|
}
|
||||||
@ -823,10 +823,10 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
|||||||
* \param buf the buffer to store the result
|
* \param buf the buffer to store the result
|
||||||
* \param size the total size of the buffer
|
* \param size the total size of the buffer
|
||||||
* \param flag flag information about the action \sa ActionSummary
|
* \param flag flag information about the action \sa ActionSummary
|
||||||
* \param seqno sequence number of the action, if it is special action with flag set,
|
* \param seqno sequence number of the action, if it is special action with flag set,
|
||||||
* seqno needs to be set to ActionSummary::kSpecialOp
|
* seqno needs to be set to ActionSummary::kSpecialOp
|
||||||
*
|
*
|
||||||
* \return if this function can return true or false
|
* \return if this function can return true or false
|
||||||
* - true means buf already set to the
|
* - true means buf already set to the
|
||||||
* result by recovering procedure, the action is complete, no further action is needed
|
* result by recovering procedure, the action is complete, no further action is needed
|
||||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||||
@ -907,7 +907,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
|||||||
* plus replication of states in previous num_local_replica hops in the ring
|
* plus replication of states in previous num_local_replica hops in the ring
|
||||||
*
|
*
|
||||||
* The input parameters must contain the valid local states available in current nodes,
|
* The input parameters must contain the valid local states available in current nodes,
|
||||||
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
|
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
|
||||||
* If there is sufficient information in the ring, when the function returns, local_chkpt will
|
* If there is sufficient information in the ring, when the function returns, local_chkpt will
|
||||||
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
|
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
|
||||||
* If there is no sufficient information in the ring, this function the number of checkpoints
|
* If there is no sufficient information in the ring, this function the number of checkpoints
|
||||||
|
|||||||
@ -5,7 +5,7 @@
|
|||||||
* using TCP non-block socket and tree-shape reduction.
|
* using TCP non-block socket and tree-shape reduction.
|
||||||
*
|
*
|
||||||
* This implementation considers the failure of nodes
|
* This implementation considers the failure of nodes
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
||||||
@ -28,13 +28,13 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
virtual void SetParam(const char *name, const char *val);
|
virtual void SetParam(const char *name, const char *val);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
* \param type_nbytes the unit number of bytes the type have
|
* \param type_nbytes the unit number of bytes the type have
|
||||||
@ -69,14 +69,14 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
*
|
*
|
||||||
* Common usage example:
|
* Common usage example:
|
||||||
* int iter = rabit::LoadCheckPoint(&model);
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
* if (iter == 0) model.InitParameters();
|
* if (iter == 0) model.InitParameters();
|
||||||
* for (i = iter; i < max_iter; ++i) {
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
* do many things, include allreduce
|
* do many things, include allreduce
|
||||||
* rabit::CheckPoint(model);
|
* rabit::CheckPoint(model);
|
||||||
* }
|
* }
|
||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
@ -85,7 +85,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
* when calling this function, the caller need to gauranttees that global_model
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
@ -105,16 +105,16 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
/*!
|
/*!
|
||||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
* when certain condition is met(see detailed expplaination).
|
* when certain condition is met(see detailed expplaination).
|
||||||
*
|
*
|
||||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
||||||
* In another words, global_model model can be changed only between last call of
|
* In another words, global_model model can be changed only between last call of
|
||||||
* Allreduce/Broadcast and LazyCheckPoint in current version
|
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||||
*
|
*
|
||||||
* For example, suppose the calling sequence is:
|
* For example, suppose the calling sequence is:
|
||||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||||
*
|
*
|
||||||
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
||||||
* improve efficiency of the program.
|
* improve efficiency of the program.
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
@ -287,6 +287,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
if (seqno_.size() == 0) return -1;
|
if (seqno_.size() == 0) return -1;
|
||||||
return seqno_.back();
|
return seqno_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// sequence number of each
|
// sequence number of each
|
||||||
std::vector<int> seqno_;
|
std::vector<int> seqno_;
|
||||||
@ -301,14 +302,14 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \brief internal consistency check function,
|
* \brief internal consistency check function,
|
||||||
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
||||||
* with or without local but not both, this function will set the approperiate settings
|
* with or without local but not both, this function will set the approperiate settings
|
||||||
* in the first call of LoadCheckPoint/CheckPoint
|
* in the first call of LoadCheckPoint/CheckPoint
|
||||||
*
|
*
|
||||||
* \param with_local whether the user calls CheckPoint with local model
|
* \param with_local whether the user calls CheckPoint with local model
|
||||||
*/
|
*/
|
||||||
void LocalModelCheck(bool with_local);
|
void LocalModelCheck(bool with_local);
|
||||||
/*!
|
/*!
|
||||||
* \brief internal implementation of checkpoint, support both lazy and normal way
|
* \brief internal implementation of checkpoint, support both lazy and normal way
|
||||||
*
|
*
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
* when calling this function, the caller need to gauranttees that global_model
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
@ -326,10 +327,10 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* after this function finishes, all the messages received and sent
|
* after this function finishes, all the messages received and sent
|
||||||
* before in all live links are discarded,
|
* before in all live links are discarded,
|
||||||
* This allows us to get a fresh start after error has happened
|
* This allows us to get a fresh start after error has happened
|
||||||
*
|
*
|
||||||
* TODO(tqchen): this function is not yet functioning was not used by engine,
|
* TODO(tqchen): this function is not yet functioning was not used by engine,
|
||||||
* simple resetlink and reconnect strategy is used
|
* simple resetlink and reconnect strategy is used
|
||||||
*
|
*
|
||||||
* \return this function can return kSuccess or kSockError
|
* \return this function can return kSuccess or kSockError
|
||||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||||
* and some link recovery proceduer is needed
|
* and some link recovery proceduer is needed
|
||||||
@ -340,7 +341,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* recover links according to the error type reported
|
* recover links according to the error type reported
|
||||||
* if there is no error, return true
|
* if there is no error, return true
|
||||||
* \param err_type the type of error happening in the system
|
* \param err_type the type of error happening in the system
|
||||||
* \return true if err_type is kSuccess, false otherwise
|
* \return true if err_type is kSuccess, false otherwise
|
||||||
*/
|
*/
|
||||||
bool CheckAndRecover(ReturnType err_type);
|
bool CheckAndRecover(ReturnType err_type);
|
||||||
/*!
|
/*!
|
||||||
@ -355,7 +356,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \param seqno sequence number of the action, if it is special action with flag set,
|
* \param seqno sequence number of the action, if it is special action with flag set,
|
||||||
* seqno needs to be set to ActionSummary::kSpecialOp
|
* seqno needs to be set to ActionSummary::kSpecialOp
|
||||||
*
|
*
|
||||||
* \return if this function can return true or false
|
* \return if this function can return true or false
|
||||||
* - true means buf already set to the
|
* - true means buf already set to the
|
||||||
* result by recovering procedure, the action is complete, no further action is needed
|
* result by recovering procedure, the action is complete, no further action is needed
|
||||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||||
@ -364,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
int seqno = ActionSummary::kSpecialOp);
|
int seqno = ActionSummary::kSpecialOp);
|
||||||
/*!
|
/*!
|
||||||
* \brief try to load check point
|
* \brief try to load check point
|
||||||
*
|
*
|
||||||
* This is a collaborative function called by all nodes
|
* This is a collaborative function called by all nodes
|
||||||
* only the nodes with requester set to true really needs to load the check point
|
* only the nodes with requester set to true really needs to load the check point
|
||||||
* other nodes acts as collaborative roles to complete this request
|
* other nodes acts as collaborative roles to complete this request
|
||||||
@ -395,7 +396,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \param p_size used to store the size of the message, for node in state kHaveData,
|
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||||
* this size must be set correctly before calling the function
|
* this size must be set correctly before calling the function
|
||||||
* for others, this surves as output parameter
|
* for others, this surves as output parameter
|
||||||
|
|
||||||
* \param p_recvlink used to store the link current node should recv data from, if necessary
|
* \param p_recvlink used to store the link current node should recv data from, if necessary
|
||||||
* this can be -1, which means current node have the data
|
* this can be -1, which means current node have the data
|
||||||
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||||
@ -432,7 +433,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* plus replication of states in previous num_local_replica hops in the ring
|
* plus replication of states in previous num_local_replica hops in the ring
|
||||||
*
|
*
|
||||||
* The input parameters must contain the valid local states available in current nodes,
|
* The input parameters must contain the valid local states available in current nodes,
|
||||||
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
|
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
|
||||||
* If there is sufficient information in the ring, when the function returns, local_chkpt will
|
* If there is sufficient information in the ring, when the function returns, local_chkpt will
|
||||||
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
|
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
|
||||||
* If there is no sufficient information in the ring, this function the number of checkpoints
|
* If there is no sufficient information in the ring, this function the number of checkpoints
|
||||||
@ -487,7 +488,7 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
LinkRecord *read_link,
|
LinkRecord *read_link,
|
||||||
LinkRecord *write_link);
|
LinkRecord *write_link);
|
||||||
/*!
|
/*!
|
||||||
* \brief run message passing algorithm on the allreduce tree
|
* \brief run message passing algorithm on the allreduce tree
|
||||||
* the result is edge message stored in p_edge_in and p_edge_out
|
* the result is edge message stored in p_edge_in and p_edge_out
|
||||||
* \param node_value the value associated with current node
|
* \param node_value the value associated with current node
|
||||||
* \param p_edge_in used to store input message from each of the edge
|
* \param p_edge_in used to store input message from each of the edge
|
||||||
@ -509,7 +510,7 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
inline ReturnType MsgPassing(const NodeType &node_value,
|
inline ReturnType MsgPassing(const NodeType &node_value,
|
||||||
std::vector<EdgeType> *p_edge_in,
|
std::vector<EdgeType> *p_edge_in,
|
||||||
std::vector<EdgeType> *p_edge_out,
|
std::vector<EdgeType> *p_edge_out,
|
||||||
EdgeType (*func)
|
EdgeType(*func)
|
||||||
(const NodeType &node_value,
|
(const NodeType &node_value,
|
||||||
const std::vector<EdgeType> &edge_in,
|
const std::vector<EdgeType> &edge_in,
|
||||||
size_t out_index));
|
size_t out_index));
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
* \file engine.cc
|
* \file engine.cc
|
||||||
* \brief this file governs which implementation of engine we are actually using
|
* \brief this file governs which implementation of engine we are actually using
|
||||||
* provides an singleton of engine interface
|
* provides an singleton of engine interface
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
@ -60,7 +60,7 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// code for reduce handle
|
// code for reduce handle
|
||||||
ReduceHandle::ReduceHandle(void)
|
ReduceHandle::ReduceHandle(void)
|
||||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||||
}
|
}
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
ReduceHandle::~ReduceHandle(void) {}
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
* \file engine_mpi.cc
|
* \file engine_mpi.cc
|
||||||
* \brief this file gives an implementation of engine interface using MPI,
|
* \brief this file gives an implementation of engine interface using MPI,
|
||||||
* this will allow rabit program to run with MPI, but do not comes with fault tolerant
|
* this will allow rabit program to run with MPI, but do not comes with fault tolerant
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
@ -143,7 +143,7 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// code for reduce handle
|
// code for reduce handle
|
||||||
ReduceHandle::ReduceHandle(void)
|
ReduceHandle::ReduceHandle(void)
|
||||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||||
}
|
}
|
||||||
ReduceHandle::~ReduceHandle(void) {
|
ReduceHandle::~ReduceHandle(void) {
|
||||||
@ -166,7 +166,7 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
|||||||
if (type_nbytes != 0) {
|
if (type_nbytes != 0) {
|
||||||
MPI::Datatype *dtype = new MPI::Datatype();
|
MPI::Datatype *dtype = new MPI::Datatype();
|
||||||
if (type_nbytes % 8 == 0) {
|
if (type_nbytes % 8 == 0) {
|
||||||
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long));
|
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
|
||||||
} else if (type_nbytes % 4 == 0) {
|
} else if (type_nbytes % 4 == 0) {
|
||||||
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
|
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
|
||||||
} else {
|
} else {
|
||||||
@ -195,7 +195,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
|||||||
dtype->Free();
|
dtype->Free();
|
||||||
}
|
}
|
||||||
if (type_nbytes % 8 == 0) {
|
if (type_nbytes % 8 == 0) {
|
||||||
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long));
|
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
|
||||||
} else if (type_nbytes % 4 == 0) {
|
} else if (type_nbytes % 4 == 0) {
|
||||||
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
|
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
59
src/socket.h
59
src/socket.h
@ -51,7 +51,7 @@ struct SockAddr {
|
|||||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||||
return std::string(buf.c_str());
|
return std::string(buf.c_str());
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set the address
|
* \brief set the address
|
||||||
* \param url the url of the address
|
* \param url the url of the address
|
||||||
* \param port the port of address
|
* \param port the port of address
|
||||||
@ -83,7 +83,7 @@ struct SockAddr {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief base class containing common operations of TCP and UDP sockets
|
* \brief base class containing common operations of TCP and UDP sockets
|
||||||
*/
|
*/
|
||||||
class Socket {
|
class Socket {
|
||||||
@ -95,7 +95,7 @@ class Socket {
|
|||||||
return sockfd;
|
return sockfd;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \return last error of socket operation
|
* \return last error of socket operation
|
||||||
*/
|
*/
|
||||||
inline static int GetLastError(void) {
|
inline static int GetLastError(void) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
@ -106,7 +106,7 @@ class Socket {
|
|||||||
}
|
}
|
||||||
/*! \return whether last error was would block */
|
/*! \return whether last error was would block */
|
||||||
inline static bool LastErrorWouldBlock(void) {
|
inline static bool LastErrorWouldBlock(void) {
|
||||||
int errsv = GetLastError();
|
int errsv = GetLastError();
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return errsv == WSAEWOULDBLOCK;
|
return errsv == WSAEWOULDBLOCK;
|
||||||
#else
|
#else
|
||||||
@ -129,15 +129,15 @@ class Socket {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief shutdown the socket module after use, all sockets need to be closed
|
* \brief shutdown the socket module after use, all sockets need to be closed
|
||||||
*/
|
*/
|
||||||
inline static void Finalize(void) {
|
inline static void Finalize(void) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set this socket to use non-blocking mode
|
* \brief set this socket to use non-blocking mode
|
||||||
* \param non_block whether set it to be non-block, if it is false
|
* \param non_block whether set it to be non-block, if it is false
|
||||||
* it will set it back to block mode
|
* it will set it back to block mode
|
||||||
@ -163,8 +163,8 @@ class Socket {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief bind the socket to an address
|
* \brief bind the socket to an address
|
||||||
* \param addr
|
* \param addr
|
||||||
*/
|
*/
|
||||||
inline void Bind(const SockAddr &addr) {
|
inline void Bind(const SockAddr &addr) {
|
||||||
@ -173,7 +173,7 @@ class Socket {
|
|||||||
Socket::Error("Bind");
|
Socket::Error("Bind");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief try bind the socket to host, from start_port to end_port
|
* \brief try bind the socket to host, from start_port to end_port
|
||||||
* \param start_port starting port number to try
|
* \param start_port starting port number to try
|
||||||
* \param end_port ending port number to try
|
* \param end_port ending port number to try
|
||||||
@ -188,11 +188,11 @@ class Socket {
|
|||||||
return port;
|
return port;
|
||||||
}
|
}
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
if (WSAGetLastError() != WSAEADDRINUSE) {
|
if (WSAGetLastError() != WSAEADDRINUSE) {
|
||||||
Socket::Error("TryBindHost");
|
Socket::Error("TryBindHost");
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
if (errno != EADDRINUSE) {
|
if (errno != EADDRINUSE) {
|
||||||
Socket::Error("TryBindHost");
|
Socket::Error("TryBindHost");
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -248,7 +248,7 @@ class Socket {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief a wrapper of TCP socket that hopefully be cross platform
|
* \brief a wrapper of TCP socket that hopefully be cross platform
|
||||||
*/
|
*/
|
||||||
class TCPSocket : public Socket{
|
class TCPSocket : public Socket{
|
||||||
@ -261,10 +261,11 @@ class TCPSocket : public Socket{
|
|||||||
/*!
|
/*!
|
||||||
* \brief enable/disable TCP keepalive
|
* \brief enable/disable TCP keepalive
|
||||||
* \param keepalive whether to set the keep alive option on
|
* \param keepalive whether to set the keep alive option on
|
||||||
*/
|
*/
|
||||||
inline void SetKeepAlive(bool keepalive) {
|
inline void SetKeepAlive(bool keepalive) {
|
||||||
int opt = static_cast<int>(keepalive);
|
int opt = static_cast<int>(keepalive);
|
||||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
|
||||||
|
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
||||||
Socket::Error("SetKeepAlive");
|
Socket::Error("SetKeepAlive");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -294,12 +295,12 @@ class TCPSocket : public Socket{
|
|||||||
return TCPSocket(newfd);
|
return TCPSocket(newfd);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief decide whether the socket is at OOB mark
|
* \brief decide whether the socket is at OOB mark
|
||||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||||
*/
|
*/
|
||||||
inline int AtMark(void) const {
|
inline int AtMark(void) const {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
unsigned long atmark;
|
unsigned long atmark; // NOLINT(*)
|
||||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||||
#else
|
#else
|
||||||
int atmark;
|
int atmark;
|
||||||
@ -307,8 +308,8 @@ class TCPSocket : public Socket{
|
|||||||
#endif
|
#endif
|
||||||
return static_cast<int>(atmark);
|
return static_cast<int>(atmark);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to an address
|
* \brief connect to an address
|
||||||
* \param addr the address to connect to
|
* \param addr the address to connect to
|
||||||
* \return whether connect is successful
|
* \return whether connect is successful
|
||||||
*/
|
*/
|
||||||
@ -328,8 +329,8 @@ class TCPSocket : public Socket{
|
|||||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief receive data using the socket
|
* \brief receive data using the socket
|
||||||
* \param buf_ the pointer to the buffer
|
* \param buf_ the pointer to the buffer
|
||||||
* \param len the size of the buffer
|
* \param len the size of the buffer
|
||||||
* \param flags extra flags
|
* \param flags extra flags
|
||||||
@ -385,7 +386,7 @@ class TCPSocket : public Socket{
|
|||||||
return ndone;
|
return ndone;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief send a string over network
|
* \brief send a string over network
|
||||||
* \param str the string to be sent
|
* \param str the string to be sent
|
||||||
*/
|
*/
|
||||||
inline void SendStr(const std::string &str) {
|
inline void SendStr(const std::string &str) {
|
||||||
@ -423,7 +424,7 @@ struct SelectHelper {
|
|||||||
maxfd = 0;
|
maxfd = 0;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief add file descriptor to watch for read
|
* \brief add file descriptor to watch for read
|
||||||
* \param fd file descriptor to be watched
|
* \param fd file descriptor to be watched
|
||||||
*/
|
*/
|
||||||
inline void WatchRead(SOCKET fd) {
|
inline void WatchRead(SOCKET fd) {
|
||||||
@ -473,7 +474,7 @@ struct SelectHelper {
|
|||||||
* \param timeout the timeout counter, can be 0, which means wait until the event happen
|
* \param timeout the timeout counter, can be 0, which means wait until the event happen
|
||||||
* \return 1 if success, 0 if timeout, and -1 if error occurs
|
* \return 1 if success, 0 if timeout, and -1 if error occurs
|
||||||
*/
|
*/
|
||||||
inline static int WaitExcept(SOCKET fd, long timeout = 0) {
|
inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*)
|
||||||
fd_set wait_set;
|
fd_set wait_set;
|
||||||
FD_ZERO(&wait_set);
|
FD_ZERO(&wait_set);
|
||||||
FD_SET(fd, &wait_set);
|
FD_SET(fd, &wait_set);
|
||||||
@ -486,10 +487,10 @@ struct SelectHelper {
|
|||||||
* \param select_write whether to watch for write event
|
* \param select_write whether to watch for write event
|
||||||
* \param select_except whether to watch for exception event
|
* \param select_except whether to watch for exception event
|
||||||
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
|
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
|
||||||
* \return number of active descriptors selected,
|
* \return number of active descriptors selected,
|
||||||
* return -1 if error occurs
|
* return -1 if error occurs
|
||||||
*/
|
*/
|
||||||
inline int Select(long timeout = 0) {
|
inline int Select(long timeout = 0) { // NOLINT(*)
|
||||||
int ret = Select_(static_cast<int>(maxfd + 1),
|
int ret = Select_(static_cast<int>(maxfd + 1),
|
||||||
&read_set, &write_set, &except_set, timeout);
|
&read_set, &write_set, &except_set, timeout);
|
||||||
if (ret == -1) {
|
if (ret == -1) {
|
||||||
@ -500,7 +501,7 @@ struct SelectHelper {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
inline static int Select_(int maxfd, fd_set *rfds,
|
inline static int Select_(int maxfd, fd_set *rfds,
|
||||||
fd_set *wfds, fd_set *efds, long timeout) {
|
fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*)
|
||||||
#if !defined(_WIN32)
|
#if !defined(_WIN32)
|
||||||
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
|
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# this is a makefile used to show testcases of rabit
|
# this is a makefile used to show testcases of rabit
|
||||||
.PHONY: all
|
.PHONY: all
|
||||||
|
|
||||||
all:
|
all: model_recover_10_10k model_recover_10_10k_die_same
|
||||||
|
|
||||||
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
||||||
model_recover_10_10k:
|
model_recover_10_10k:
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Python interface for rabit
|
|||||||
Reliable Allreduce and Broadcast Library
|
Reliable Allreduce and Broadcast Library
|
||||||
Author: Tianqi Chen
|
Author: Tianqi Chen
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
@ -17,10 +18,12 @@ else:
|
|||||||
rbtlib = None
|
rbtlib = None
|
||||||
|
|
||||||
# load in xgboost library
|
# load in xgboost library
|
||||||
def loadlib__(lib = 'standard'):
|
def loadlib__(lib='standard'):
|
||||||
|
"""Load rabit library"""
|
||||||
global rbtlib
|
global rbtlib
|
||||||
if rbtlib != None:
|
if rbtlib != None:
|
||||||
warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2)
|
warnings.warn('rabit.int call was ignored because it has'\
|
||||||
|
' already been initialized', level=2)
|
||||||
return
|
return
|
||||||
if lib == 'standard':
|
if lib == 'standard':
|
||||||
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '')
|
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '')
|
||||||
@ -35,6 +38,7 @@ def loadlib__(lib = 'standard'):
|
|||||||
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||||
|
|
||||||
def unloadlib__():
|
def unloadlib__():
|
||||||
|
"""Unload rabit library"""
|
||||||
global rbtlib
|
global rbtlib
|
||||||
del rbtlib
|
del rbtlib
|
||||||
rbtlib = None
|
rbtlib = None
|
||||||
@ -45,13 +49,13 @@ MIN = 1
|
|||||||
SUM = 2
|
SUM = 2
|
||||||
BITOR = 3
|
BITOR = 3
|
||||||
|
|
||||||
def check_err__():
|
def check_err__():
|
||||||
"""
|
"""
|
||||||
reserved function used to check error
|
reserved function used to check error
|
||||||
"""
|
"""
|
||||||
return
|
return
|
||||||
|
|
||||||
def init(args = sys.argv, lib = 'standard'):
|
def init(args=sys.argv, lib='standard'):
|
||||||
"""
|
"""
|
||||||
intialize the rabit module, call this once before using anything
|
intialize the rabit module, call this once before using anything
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -69,7 +73,7 @@ def init(args = sys.argv, lib = 'standard'):
|
|||||||
|
|
||||||
def finalize():
|
def finalize():
|
||||||
"""
|
"""
|
||||||
finalize the rabit engine, call this function after you finished all jobs
|
finalize the rabit engine, call this function after you finished all jobs
|
||||||
"""
|
"""
|
||||||
rbtlib.RabitFinalize()
|
rbtlib.RabitFinalize()
|
||||||
check_err__()
|
check_err__()
|
||||||
@ -132,7 +136,7 @@ def broadcast(data, root):
|
|||||||
print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s))
|
print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s))
|
||||||
rabit.finalize()
|
rabit.finalize()
|
||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
data: anytype that can be pickled
|
data: anytype that can be pickled
|
||||||
input data, if current rank does not equal root, this can be None
|
input data, if current rank does not equal root, this can be None
|
||||||
@ -145,12 +149,12 @@ def broadcast(data, root):
|
|||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
if root == rank:
|
if root == rank:
|
||||||
assert data is not None, 'need to pass in data when broadcasting'
|
assert data is not None, 'need to pass in data when broadcasting'
|
||||||
s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL)
|
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
length.value = len(s)
|
length.value = len(s)
|
||||||
# run first broadcast
|
# run first broadcast
|
||||||
rbtlib.RabitBroadcast(ctypes.byref(length),
|
rbtlib.RabitBroadcast(ctypes.byref(length),
|
||||||
ctypes.sizeof(ctypes.c_ulong),
|
ctypes.sizeof(ctypes.c_ulong),
|
||||||
root)
|
root)
|
||||||
check_err__()
|
check_err__()
|
||||||
if root != rank:
|
if root != rank:
|
||||||
dptr = (ctypes.c_char * length.value)()
|
dptr = (ctypes.c_char * length.value)()
|
||||||
@ -179,18 +183,19 @@ DTYPE_ENUM__ = {
|
|||||||
np.dtype('float64') : 7
|
np.dtype('float64') : 7
|
||||||
}
|
}
|
||||||
|
|
||||||
def allreduce(data, op, prepare_fun = None):
|
def allreduce(data, op, prepare_fun=None):
|
||||||
"""
|
"""
|
||||||
perform allreduce, return the result, this function is not thread-safe
|
perform allreduce, return the result, this function is not thread-safe
|
||||||
Arguments:
|
Arguments:
|
||||||
data: numpy ndarray
|
data: numpy ndarray
|
||||||
input data
|
input data
|
||||||
op: int
|
op: int
|
||||||
reduction operators, can be MIN, MAX, SUM, BITOR
|
reduction operators, can be MIN, MAX, SUM, BITOR
|
||||||
prepare_fun: lambda data
|
prepare_fun: lambda data
|
||||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||||
will be called by the function before performing allreduce, to intialize the data
|
will be called by the function before performing allreduce, to intialize the data
|
||||||
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
If the result of Allreduce can be recovered directly,
|
||||||
|
then prepare_fun will NOT be called
|
||||||
Returns:
|
Returns:
|
||||||
the result of allreduce, have same shape as data
|
the result of allreduce, have same shape as data
|
||||||
"""
|
"""
|
||||||
@ -206,12 +211,13 @@ def allreduce(data, op, prepare_fun = None):
|
|||||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
op, None, None)
|
op, None, None)
|
||||||
else:
|
else:
|
||||||
PFUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||||
def pfunc(args):
|
def pfunc(args):
|
||||||
|
"""prepare function."""
|
||||||
prepare_fun(data)
|
prepare_fun(data)
|
||||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
op, PFUNC(pfunc), None)
|
op, func_ptr(pfunc), None)
|
||||||
check_err__()
|
check_err__()
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
@ -229,49 +235,49 @@ def load_model__(ptr, length):
|
|||||||
data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents))
|
data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents))
|
||||||
return pickle.loads(data.raw)
|
return pickle.loads(data.raw)
|
||||||
|
|
||||||
def load_checkpoint(with_local = False):
|
def load_checkpoint(with_local=False):
|
||||||
"""
|
"""
|
||||||
load latest check point
|
load latest check point
|
||||||
Arguments:
|
Arguments:
|
||||||
with_local: boolean [default = False]
|
with_local: boolean [default = False]
|
||||||
whether the checkpoint contains local model
|
whether the checkpoint contains local model
|
||||||
Returns:
|
Returns:
|
||||||
if with_local: return (version, gobal_model, local_model)
|
if with_local: return (version, gobal_model, local_model)
|
||||||
else return (version, gobal_model)
|
else return (version, gobal_model)
|
||||||
if returned version == 0, this means no model has been CheckPointed
|
if returned version == 0, this means no model has been CheckPointed
|
||||||
and global_model, local_model returned will be None
|
and global_model, local_model returned will be None
|
||||||
"""
|
"""
|
||||||
gp = ctypes.POINTER(ctypes.c_char)()
|
gptr = ctypes.POINTER(ctypes.c_char)()
|
||||||
global_len = ctypes.c_ulong()
|
global_len = ctypes.c_ulong()
|
||||||
if with_local:
|
if with_local:
|
||||||
lp = ctypes.POINTER(ctypes.c_char)()
|
lptr = ctypes.POINTER(ctypes.c_char)()
|
||||||
local_len = ctypes.c_ulong()
|
local_len = ctypes.c_ulong()
|
||||||
version = rbtlib.RabitLoadCheckPoint(
|
version = rbtlib.RabitLoadCheckPoint(
|
||||||
ctypes.byref(gp),
|
ctypes.byref(gptr),
|
||||||
ctypes.byref(global_len),
|
ctypes.byref(global_len),
|
||||||
ctypes.byref(lp),
|
ctypes.byref(lptr),
|
||||||
ctypes.byref(local_len))
|
ctypes.byref(local_len))
|
||||||
check_err__()
|
check_err__()
|
||||||
if version == 0:
|
if version == 0:
|
||||||
return (version, None, None)
|
return (version, None, None)
|
||||||
return (version,
|
return (version,
|
||||||
load_model__(gp, global_len.value),
|
load_model__(gptr, global_len.value),
|
||||||
load_model__(lp, local_len.value))
|
load_model__(lptr, local_len.value))
|
||||||
else:
|
else:
|
||||||
version = rbtlib.RabitLoadCheckPoint(
|
version = rbtlib.RabitLoadCheckPoint(
|
||||||
ctypes.byref(gp),
|
ctypes.byref(gptr),
|
||||||
ctypes.byref(global_len),
|
ctypes.byref(global_len),
|
||||||
None, None)
|
None, None)
|
||||||
check_err__()
|
check_err__()
|
||||||
if version == 0:
|
if version == 0:
|
||||||
return (version, None)
|
return (version, None)
|
||||||
return (version,
|
return (version,
|
||||||
load_model__(gp, global_len.value))
|
load_model__(gptr, global_len.value))
|
||||||
|
|
||||||
def checkpoint(global_model, local_model = None):
|
def checkpoint(global_model, local_model=None):
|
||||||
"""
|
"""
|
||||||
checkpoint the model, meaning we finished a stage of execution
|
checkpoint the model, meaning we finished a stage of execution
|
||||||
every time we call check point, there is a version number which will increase by one
|
every time we call check point, there is a version number which will increase by one
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
global_model: anytype that can be pickled
|
global_model: anytype that can be pickled
|
||||||
@ -285,16 +291,17 @@ def checkpoint(global_model, local_model = None):
|
|||||||
while global_model do not need explicit replication.
|
while global_model do not need explicit replication.
|
||||||
It is recommended to use global_model if possible
|
It is recommended to use global_model if possible
|
||||||
"""
|
"""
|
||||||
sg = pickle.dumps(global_model)
|
sglobal = pickle.dumps(global_model)
|
||||||
if local_model is None:
|
if local_model is None:
|
||||||
rbtlib.RabitCheckPoint(sg, len(sg), None, 0)
|
rbtlib.RabitCheckPoint(sglobal, len(sglobal), None, 0)
|
||||||
check_err__()
|
check_err__()
|
||||||
del sg;
|
del sglobal
|
||||||
else:
|
else:
|
||||||
sl = pickle.dumps(local_model)
|
slocal = pickle.dumps(local_model)
|
||||||
rbtlib.RabitCheckPoint(sg, len(sg), sl, len(sl))
|
rbtlib.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal))
|
||||||
check_err__()
|
check_err__()
|
||||||
del sl; del sg;
|
del slocal
|
||||||
|
del sglobal
|
||||||
|
|
||||||
def version_number():
|
def version_number():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
// Copyright by Contributors
|
||||||
// implementations in ctypes
|
// implementations in ctypes
|
||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
@ -28,7 +29,7 @@ struct FHelper<op::BitOR, DType> {
|
|||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
utils::Error("DataType does not support bitwise or operation");
|
utils::Error("DataType does not support bitwise or operation");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template<typename OP>
|
template<typename OP>
|
||||||
inline void Allreduce_(void *sendrecvbuf_,
|
inline void Allreduce_(void *sendrecvbuf_,
|
||||||
@ -60,12 +61,12 @@ inline void Allreduce_(void *sendrecvbuf_,
|
|||||||
return;
|
return;
|
||||||
case kLong:
|
case kLong:
|
||||||
rabit::Allreduce<OP>
|
rabit::Allreduce<OP>
|
||||||
(static_cast<long*>(sendrecvbuf_),
|
(static_cast<long*>(sendrecvbuf_), // NOLINT(*)
|
||||||
count, prepare_fun, prepare_arg);
|
count, prepare_fun, prepare_arg);
|
||||||
return;
|
return;
|
||||||
case kULong:
|
case kULong:
|
||||||
rabit::Allreduce<OP>
|
rabit::Allreduce<OP>
|
||||||
(static_cast<unsigned long*>(sendrecvbuf_),
|
(static_cast<unsigned long*>(sendrecvbuf_), // NOLINT(*)
|
||||||
count, prepare_fun, prepare_arg);
|
count, prepare_fun, prepare_arg);
|
||||||
return;
|
return;
|
||||||
case kFloat:
|
case kFloat:
|
||||||
@ -135,7 +136,7 @@ struct ReadWrapper : public Serializable {
|
|||||||
}
|
}
|
||||||
virtual void Save(Stream *fo) const {
|
virtual void Save(Stream *fo) const {
|
||||||
utils::Error("not implemented");
|
utils::Error("not implemented");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct WriteWrapper : public Serializable {
|
struct WriteWrapper : public Serializable {
|
||||||
const char *data;
|
const char *data;
|
||||||
@ -179,7 +180,7 @@ extern "C" {
|
|||||||
if (s.length() > max_len) {
|
if (s.length() > max_len) {
|
||||||
s.resize(max_len - 1);
|
s.resize(max_len - 1);
|
||||||
}
|
}
|
||||||
strcpy(out_name, s.c_str());
|
strcpy(out_name, s.c_str()); // NOLINT(*)
|
||||||
*out_len = static_cast<rbt_ulong>(s.length());
|
*out_len = static_cast<rbt_ulong>(s.length());
|
||||||
}
|
}
|
||||||
void RabitBroadcast(void *sendrecv_data,
|
void RabitBroadcast(void *sendrecv_data,
|
||||||
@ -218,7 +219,7 @@ extern "C" {
|
|||||||
*out_local_model = BeginPtr(local_buffer);
|
*out_local_model = BeginPtr(local_buffer);
|
||||||
*out_local_len = static_cast<rbt_ulong>(local_buffer.length());
|
*out_local_len = static_cast<rbt_ulong>(local_buffer.length());
|
||||||
}
|
}
|
||||||
return version;
|
return version;
|
||||||
}
|
}
|
||||||
void RabitCheckPoint(const char *global_model,
|
void RabitCheckPoint(const char *global_model,
|
||||||
rbt_ulong global_len,
|
rbt_ulong global_len,
|
||||||
|
|||||||
@ -1,18 +1,19 @@
|
|||||||
#ifndef RABIT_WRAPPER_H_
|
|
||||||
#define RABIT_WRAPPER_H_
|
|
||||||
/*!
|
/*!
|
||||||
|
* Copyright by Contributors
|
||||||
* \file rabit_wrapper.h
|
* \file rabit_wrapper.h
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
* \brief a C style wrapper of rabit
|
* \brief a C style wrapper of rabit
|
||||||
* can be used to create wrapper of other languages
|
* can be used to create wrapper of other languages
|
||||||
*/
|
*/
|
||||||
|
#ifndef RABIT_WRAPPER_H_
|
||||||
|
#define RABIT_WRAPPER_H_
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
#define RABIT_DLL __declspec(dllexport)
|
#define RABIT_DLL __declspec(dllexport)
|
||||||
#else
|
#else
|
||||||
#define RABIT_DLL
|
#define RABIT_DLL
|
||||||
#endif
|
#endif
|
||||||
// manually define unsign long
|
// manually define unsign long
|
||||||
typedef unsigned long rbt_ulong;
|
typedef unsigned long rbt_ulong; // NOLINT(*)
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -23,8 +24,8 @@ extern "C" {
|
|||||||
* \param argv the array of input arguments
|
* \param argv the array of input arguments
|
||||||
*/
|
*/
|
||||||
RABIT_DLL void RabitInit(int argc, char *argv[]);
|
RABIT_DLL void RabitInit(int argc, char *argv[]);
|
||||||
/*!
|
/*!
|
||||||
* \brief finalize the rabit engine, call this function after you finished all jobs
|
* \brief finalize the rabit engine, call this function after you finished all jobs
|
||||||
*/
|
*/
|
||||||
RABIT_DLL void RabitFinalize(void);
|
RABIT_DLL void RabitFinalize(void);
|
||||||
/*! \brief get rank of current process */
|
/*! \brief get rank of current process */
|
||||||
@ -37,9 +38,9 @@ extern "C" {
|
|||||||
* the user who monitors the tracker
|
* the user who monitors the tracker
|
||||||
* \param msg the message to be printed
|
* \param msg the message to be printed
|
||||||
*/
|
*/
|
||||||
RABIT_DLL void RabitTrackerPrint(const char *msg);
|
RABIT_DLL void RabitTrackerPrint(const char *msg);
|
||||||
/*!
|
/*!
|
||||||
* \brief get name of processor
|
* \brief get name of processor
|
||||||
* \param out_name hold output string
|
* \param out_name hold output string
|
||||||
* \param out_len hold length of output string
|
* \param out_len hold length of output string
|
||||||
* \param max_len maximum buffer length of input
|
* \param max_len maximum buffer length of input
|
||||||
@ -50,7 +51,7 @@ extern "C" {
|
|||||||
/*!
|
/*!
|
||||||
* \brief broadcast an memory region to all others from root
|
* \brief broadcast an memory region to all others from root
|
||||||
*
|
*
|
||||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||||
* \param sendrecv_data the pointer to send or recive buffer,
|
* \param sendrecv_data the pointer to send or recive buffer,
|
||||||
* \param size the size of the data
|
* \param size the size of the data
|
||||||
* \param root the root of process
|
* \param root the root of process
|
||||||
@ -58,7 +59,7 @@ extern "C" {
|
|||||||
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
|
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
|
||||||
rbt_ulong size, int root);
|
rbt_ulong size, int root);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
*
|
*
|
||||||
* Example Usage: the following code gives sum of the result
|
* Example Usage: the following code gives sum of the result
|
||||||
@ -81,14 +82,14 @@ extern "C" {
|
|||||||
int enum_op,
|
int enum_op,
|
||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg);
|
void *prepare_arg);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param out_global_model hold output of serialized global_model
|
* \param out_global_model hold output of serialized global_model
|
||||||
* \param out_global_len the output length of serialized global model
|
* \param out_global_len the output length of serialized global model
|
||||||
* \param out_local_model hold output of serialized local_model, can be NULL
|
* \param out_local_model hold output of serialized local_model, can be NULL
|
||||||
* \param out_local_len the output length of serialized local model, can be NULL
|
* \param out_local_len the output length of serialized local model, can be NULL
|
||||||
*
|
*
|
||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* nothing will be touched
|
* nothing will be touched
|
||||||
@ -100,7 +101,7 @@ extern "C" {
|
|||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param global_model hold content of serialized global_model
|
* \param global_model hold content of serialized global_model
|
||||||
* \param global_len the content length of serialized global model
|
* \param global_len the content length of serialized global model
|
||||||
* \param local_model hold content of serialized local_model, can be NULL
|
* \param local_model hold content of serialized local_model, can be NULL
|
||||||
@ -122,4 +123,4 @@ extern "C" {
|
|||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // C
|
} // C
|
||||||
#endif
|
#endif
|
||||||
#endif // XGBOOST_WRAPPER_H_
|
#endif // RABIT_WRAPPER_H_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user