Merge rabit
This commit is contained in:
19
rabit/include/rabit/base.h
Normal file
19
rabit/include/rabit/base.h
Normal file
@@ -0,0 +1,19 @@
|
||||
/*!
|
||||
* Copyright (c) 2020 by Contributors
|
||||
* \file base.h
|
||||
* \brief Macros common to all headers
|
||||
*
|
||||
* \author Hyunsu Cho
|
||||
*/
|
||||
|
||||
#ifndef RABIT_BASE_H_
|
||||
#define RABIT_BASE_H_
|
||||
|
||||
#ifndef _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#endif // _CRT_SECURE_NO_WARNINGS
|
||||
#ifndef _CRT_SECURE_NO_DEPRECATE
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#endif // _CRT_SECURE_NO_DEPRECATE
|
||||
|
||||
#endif // RABIT_BASE_H_
|
||||
196
rabit/include/rabit/c_api.h
Normal file
196
rabit/include/rabit/c_api.h
Normal file
@@ -0,0 +1,196 @@
|
||||
/*!
|
||||
* Copyright by Contributors
|
||||
* \file c_api.h
|
||||
* \author Tianqi Chen
|
||||
* \brief a C style API of rabit.
|
||||
*/
|
||||
#ifndef RABIT_C_API_H_
|
||||
#define RABIT_C_API_H_
|
||||
|
||||
#ifdef __cplusplus
|
||||
#define RABIT_EXTERN_C extern "C"
|
||||
#include <cstdio>
|
||||
#else
|
||||
#define RABIT_EXTERN_C
|
||||
#include <stdio.h>
|
||||
#endif // __cplusplus
|
||||
|
||||
#if defined(_MSC_VER) || defined(_WIN32)
|
||||
#define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport)
|
||||
#else
|
||||
#define RABIT_DLL RABIT_EXTERN_C __attribute__ ((visibility ("default")))
|
||||
#endif // defined(_MSC_VER) || defined(_WIN32)
|
||||
|
||||
/*! \brief rabit unsigned long type */
|
||||
typedef unsigned long rbt_ulong; // NOLINT(*)
|
||||
|
||||
/*!
|
||||
* \brief intialize the rabit module,
|
||||
* call this once before using anything
|
||||
* The additional arguments is not necessary.
|
||||
* Usually rabit will detect settings
|
||||
* from environment variables.
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if rabit is initialized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL bool RabitInit(int argc, char *argv[]);
|
||||
|
||||
/*!
|
||||
* \brief finalize the rabit engine,
|
||||
* call this function after you finished all jobs.
|
||||
* \return true if rabit is initialized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL bool RabitFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of previous process in ring topology
|
||||
* \return rank number of worker
|
||||
* */
|
||||
RABIT_DLL int RabitGetRingPrevRank(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of current process
|
||||
* \return rank number of worker
|
||||
* */
|
||||
RABIT_DLL int RabitGetRank(void);
|
||||
|
||||
/*!
|
||||
* \brief get total number of process
|
||||
* \return total world size
|
||||
* */
|
||||
RABIT_DLL int RabitGetWorldSize(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of current process
|
||||
* \return if rabit is distributed
|
||||
* */
|
||||
RABIT_DLL int RabitIsDistributed(void);
|
||||
|
||||
/*!
|
||||
* \brief print the msg to the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
RABIT_DLL void RabitTrackerPrint(const char *msg);
|
||||
/*!
|
||||
* \brief get name of processor
|
||||
* \param out_name hold output string
|
||||
* \param out_len hold length of output string
|
||||
* \param max_len maximum buffer length of input
|
||||
*/
|
||||
RABIT_DLL void RabitGetProcessorName(char *out_name,
|
||||
rbt_ulong *out_len,
|
||||
rbt_ulong max_len);
|
||||
/*!
|
||||
* \brief broadcast an memory region to all others from root
|
||||
*
|
||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||
* \param sendrecv_data the pointer to send or recive buffer,
|
||||
* \param size the size of the data
|
||||
* \param root the root of process
|
||||
*/
|
||||
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
|
||||
rbt_ulong size, int root);
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype
|
||||
* \param size_node_slice size of the current node slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
RABIT_DLL void RabitAllgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype);
|
||||
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size());
|
||||
* ...
|
||||
* \param sendrecvbuf buffer for both sending and recving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
|
||||
* \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
RABIT_DLL void RabitAllreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
int enum_dtype,
|
||||
int enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg);
|
||||
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param out_global_model hold output of serialized global_model
|
||||
* \param out_global_len the output length of serialized global model
|
||||
* \param out_local_model hold output of serialized local_model, can be NULL
|
||||
* \param out_local_len the output length of serialized local model, can be NULL
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* nothing will be touched
|
||||
*/
|
||||
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
|
||||
rbt_ulong *out_global_len,
|
||||
char **out_local_model,
|
||||
rbt_ulong *out_local_len);
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param global_model hold content of serialized global_model
|
||||
* \param global_len the content length of serialized global model
|
||||
* \param local_model hold content of serialized local_model, can be NULL
|
||||
* \param local_len the content length of serialized local model, can be NULL
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*/
|
||||
RABIT_DLL void RabitCheckPoint(const char *global_model,
|
||||
rbt_ulong global_len,
|
||||
const char *local_model,
|
||||
rbt_ulong local_len);
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \return rabit version number
|
||||
*/
|
||||
RABIT_DLL int RabitVersionNumber(void);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief a Dummy function,
|
||||
* used to cause force link of C API into the DLL.
|
||||
* \code
|
||||
* \/\/force link rabit C API library.
|
||||
* static int must_link_rabit_ = RabitLinkTag();
|
||||
* \endcode
|
||||
* \return a dummy integer.
|
||||
*/
|
||||
RABIT_DLL int RabitLinkTag(void);
|
||||
|
||||
#endif // RABIT_C_API_H_
|
||||
346
rabit/include/rabit/internal/engine.h
Normal file
346
rabit/include/rabit/internal/engine.h
Normal file
@@ -0,0 +1,346 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine.h
|
||||
* \brief This file defines the core interface of rabit library
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_ENGINE_H_
|
||||
#define RABIT_INTERNAL_ENGINE_H_
|
||||
#include <string>
|
||||
#include "rabit/serializable.h"
|
||||
|
||||
#if (defined(__GNUC__) && !defined(__clang__))
|
||||
#define _FILE __builtin_FILE()
|
||||
#define _LINE __builtin_LINE()
|
||||
#define _CALLER __builtin_FUNCTION()
|
||||
#else
|
||||
#define _FILE "N/A"
|
||||
#define _LINE -1
|
||||
#define _CALLER "N/A"
|
||||
#endif // (defined(__GNUC__) && !defined(__clang__))
|
||||
|
||||
namespace MPI {
|
||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||
class Datatype;
|
||||
}
|
||||
|
||||
/*! \brief namespace of rabit */
|
||||
namespace rabit {
|
||||
/*! \brief core interface of the engine */
|
||||
namespace engine {
|
||||
/*! \brief interface of core Allreduce engine */
|
||||
class IEngine {
|
||||
public:
|
||||
/*!
|
||||
* \brief Preprocessing function, that is called before AllReduce,
|
||||
* used to prepare the data used by AllReduce
|
||||
* \param arg additional possible argument used to invoke the preprocessor
|
||||
*/
|
||||
typedef void (PreprocFunction) (void *arg);
|
||||
/*!
|
||||
* \brief reduce function, the same form of MPI reduce function is used,
|
||||
* to be compatible with MPI interface
|
||||
* In all the functions, the memory is ensured to aligned to 64-bit
|
||||
* which means it is OK to cast src,dst to double* int* etc
|
||||
* \param src pointer to source space
|
||||
* \param dst pointer to destination reduction
|
||||
* \param count total number of elements to be reduced (note this is total number of elements instead of bytes)
|
||||
* the definition of the reduce function should be type aware
|
||||
* \param dtype the data type object, to be compatible with MPI reduce
|
||||
*/
|
||||
typedef void (ReduceFunction) (const void *src,
|
||||
void *dst, int count,
|
||||
const MPI::Datatype &dtype);
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IEngine() {}
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) = 0;
|
||||
/*!
|
||||
* \brief performs in-place Allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the number of bytes the type has
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) = 0;
|
||||
/*!
|
||||
* \brief broadcasts data from root to every other node
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) = 0;
|
||||
/*!
|
||||
* \brief explicitly re-initialize everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throws an exception,
|
||||
* this function should only be used for test purposes
|
||||
*/
|
||||
virtual void InitAfterException(void) = 0;
|
||||
/*!
|
||||
* \brief loads the latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that the global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to the local model that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of the model loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, users should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \brief checkpoints the model, meaning a stage of execution was finished
|
||||
* every time we call check point, a version number increases by ones
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that the global_model
|
||||
* is the same in every node
|
||||
* \param local_model pointer to the local model that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model does not need explicit replication.
|
||||
* So, only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met (see detailed explanation).
|
||||
*
|
||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||
* In other words, global_model can be changed only between the last call of
|
||||
* Allreduce/Broadcast and LazyCheckPoint in the current version
|
||||
*
|
||||
* For example, suppose the calling sequence is:
|
||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||
*
|
||||
* If the user can only change global_model in code3, then LazyCheckPoint can be used to
|
||||
* improve the efficiency of the program.
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in every node
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) = 0;
|
||||
/*!
|
||||
* \return version number of the current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
virtual int VersionNumber(void) const = 0;
|
||||
/*! \brief gets rank of previous node in ring topology */
|
||||
virtual int GetRingPrevRank(void) const = 0;
|
||||
/*! \brief gets rank of current node */
|
||||
virtual int GetRank(void) const = 0;
|
||||
/*! \brief gets total number of nodes */
|
||||
virtual int GetWorldSize(void) const = 0;
|
||||
/*! \brief whether we run in distribted mode */
|
||||
virtual bool IsDistributed(void) const = 0;
|
||||
/*! \brief gets the host name of the current node */
|
||||
virtual std::string GetHost(void) const = 0;
|
||||
/*!
|
||||
* \brief prints the msg in the tracker,
|
||||
* this function can be used to communicate progress information to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg) = 0;
|
||||
};
|
||||
|
||||
/*! \brief initializes the engine module */
|
||||
bool Init(int argc, char *argv[]);
|
||||
/*! \brief finalizes the engine module */
|
||||
bool Finalize(void);
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void);
|
||||
|
||||
/*! \brief namespace that contains stubs to be compatible with MPI */
|
||||
namespace mpi {
|
||||
/*!\brief enum of all operators */
|
||||
enum OpType {
|
||||
kMax = 0,
|
||||
kMin = 1,
|
||||
kSum = 2,
|
||||
kBitwiseOR = 3
|
||||
};
|
||||
/*!\brief enum of supported data types */
|
||||
enum DataType {
|
||||
kChar = 0,
|
||||
kUChar = 1,
|
||||
kInt = 2,
|
||||
kUInt = 3,
|
||||
kLong = 4,
|
||||
kULong = 5,
|
||||
kFloat = 6,
|
||||
kDouble = 7,
|
||||
kLongLong = 8,
|
||||
kULongLong = 9
|
||||
};
|
||||
} // namespace mpi
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allgather(void* sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief perform in-place Allreduce, on sendrecvbuf
|
||||
* this is an internal function used by rabit to be able to compile with MPI
|
||||
* do not use this function directly
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param type_nbytes the number of bytes the type has
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param dtype the data type
|
||||
* \param op the reduce operator type
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function.
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief handle for customized reducer, used to handle customized reduce
|
||||
* this class is mainly created for compatiblity issues with MPI's customized reduce
|
||||
*/
|
||||
class ReduceHandle {
|
||||
public:
|
||||
// constructor
|
||||
ReduceHandle(void);
|
||||
// destructor
|
||||
~ReduceHandle(void);
|
||||
/*!
|
||||
* \brief initialize the reduce function,
|
||||
* with the type the reduce function needs to deal with
|
||||
* the reduce function MUST be communicative
|
||||
*/
|
||||
void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation
|
||||
* \param sendrecvbuf the in place send-recv buffer
|
||||
* \param type_n4bytes size of the type, in terms of 4bytes
|
||||
* \param count number of elements to send
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*! \return the number of bytes occupied by the type */
|
||||
static int TypeSize(const MPI::Datatype &dtype);
|
||||
|
||||
protected:
|
||||
// handle function field
|
||||
void *handle_;
|
||||
// reduce function of the reducer
|
||||
IEngine::ReduceFunction *redfunc_;
|
||||
// handle to the type field
|
||||
void *htype_;
|
||||
// the created type in 4 bytes
|
||||
size_t created_type_nbytes_;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_ENGINE_H_
|
||||
114
rabit/include/rabit/internal/io.h
Normal file
114
rabit/include/rabit/internal/io.h
Normal file
@@ -0,0 +1,114 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2019 by Contributors
|
||||
* \file io.h
|
||||
* \brief utilities with different serializable implementations
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_IO_H_
|
||||
#define RABIT_INTERNAL_IO_H_
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <limits>
|
||||
#include "rabit/internal/utils.h"
|
||||
#include "rabit/serializable.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief re-use definition of dmlc::SeekStream */
|
||||
typedef dmlc::SeekStream SeekStream;
|
||||
/*! \brief fixed size memory buffer */
|
||||
struct MemoryFixSizeBuffer : public SeekStream {
|
||||
public:
|
||||
// similar to SEEK_END in libc
|
||||
static size_t constexpr SeekEnd = std::numeric_limits<size_t>::max();
|
||||
|
||||
public:
|
||||
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
|
||||
buffer_size_(buffer_size) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
virtual ~MemoryFixSizeBuffer(void) {}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
if (size == 0) return;
|
||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||
"write position exceed fixed buffer size");
|
||||
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
if (pos == SeekEnd) {
|
||||
curr_ptr_ = buffer_size_;
|
||||
} else {
|
||||
curr_ptr_ = static_cast<size_t>(pos);
|
||||
}
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
return curr_ptr_;
|
||||
}
|
||||
virtual bool AtEnd(void) const {
|
||||
return curr_ptr_ == buffer_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief in memory buffer */
|
||||
char *p_buffer_;
|
||||
/*! \brief current pointer */
|
||||
size_t buffer_size_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryFixSizeBuffer
|
||||
|
||||
/*! \brief a in memory buffer that can be read and write as stream interface */
|
||||
struct MemoryBufferStream : public SeekStream {
|
||||
public:
|
||||
explicit MemoryBufferStream(std::string *p_buffer)
|
||||
: p_buffer_(p_buffer) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
virtual ~MemoryBufferStream(void) {}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
||||
"read can not have position excceed buffer length");
|
||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
if (size == 0) return;
|
||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||
p_buffer_->resize(curr_ptr_+size);
|
||||
}
|
||||
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
curr_ptr_ = static_cast<size_t>(pos);
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
return curr_ptr_;
|
||||
}
|
||||
virtual bool AtEnd(void) const {
|
||||
return curr_ptr_ == p_buffer_->length();
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief in memory buffer */
|
||||
std::string *p_buffer_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryBufferStream
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_IO_H_
|
||||
386
rabit/include/rabit/internal/rabit-inl.h
Normal file
386
rabit/include/rabit/internal/rabit-inl.h
Normal file
@@ -0,0 +1,386 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2019 by Contributors
|
||||
* \file rabit-inl.h
|
||||
* \brief implementation of inline template function for rabit interface
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_RABIT_INL_H_
|
||||
#define RABIT_INTERNAL_RABIT_INL_H_
|
||||
// use engine for implementation
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "rabit/internal/io.h"
|
||||
#include "rabit/internal/utils.h"
|
||||
#include "rabit/rabit.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
namespace mpi {
|
||||
// template function to translate type to enum indicator
|
||||
template<typename DType>
|
||||
inline DataType GetType(void);
|
||||
template<>
|
||||
inline DataType GetType<char>(void) {
|
||||
return kChar;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned char>(void) {
|
||||
return kUChar;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<int>(void) {
|
||||
return kInt;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned int>(void) { // NOLINT(*)
|
||||
return kUInt;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<long>(void) { // NOLINT(*)
|
||||
return kLong;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned long>(void) { // NOLINT(*)
|
||||
return kULong;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<float>(void) {
|
||||
return kFloat;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<double>(void) {
|
||||
return kDouble;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<long long>(void) { // NOLINT(*)
|
||||
return kLongLong;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
|
||||
return kULongLong;
|
||||
}
|
||||
} // namespace mpi
|
||||
} // namespace engine
|
||||
|
||||
namespace op {
|
||||
struct Max {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kMax;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
if (dst < src) dst = src;
|
||||
}
|
||||
};
|
||||
struct Min {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kMin;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
if (dst > src) dst = src;
|
||||
}
|
||||
};
|
||||
struct Sum {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kSum;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
dst += src;
|
||||
}
|
||||
};
|
||||
struct BitOR {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
dst |= src;
|
||||
}
|
||||
};
|
||||
template<typename OP, typename DType>
|
||||
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||
const DType* src = (const DType*)src_;
|
||||
DType* dst = (DType*)dst_; // NOLINT(*)
|
||||
for (int i = 0; i < len; i++) {
|
||||
OP::Reduce(dst[i], src[i]);
|
||||
}
|
||||
}
|
||||
} // namespace op
|
||||
|
||||
// intialize the rabit engine
|
||||
inline bool Init(int argc, char *argv[]) {
|
||||
return engine::Init(argc, argv);
|
||||
}
|
||||
// finalize the rabit engine
|
||||
inline bool Finalize(void) {
|
||||
return engine::Finalize();
|
||||
}
|
||||
// get the rank of the previous worker in ring topology
|
||||
inline int GetRingPrevRank(void) {
|
||||
return engine::GetEngine()->GetRingPrevRank();
|
||||
}
|
||||
// get the rank of current process
|
||||
inline int GetRank(void) {
|
||||
return engine::GetEngine()->GetRank();
|
||||
}
|
||||
// the the size of the world
|
||||
inline int GetWorldSize(void) {
|
||||
return engine::GetEngine()->GetWorldSize();
|
||||
}
|
||||
// whether rabit is distributed
|
||||
inline bool IsDistributed(void) {
|
||||
return engine::GetEngine()->IsDistributed();
|
||||
}
|
||||
// get the name of current processor
|
||||
inline std::string GetProcessorName(void) {
|
||||
return engine::GetEngine()->GetHost();
|
||||
}
|
||||
// broadcast data to all other nodes from root
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::GetEngine()->Broadcast(sendrecv_data, size, root,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
template<typename DType>
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
size_t size = sendrecv_data->size();
|
||||
Broadcast(&size, sizeof(size), root, _file, _line, _caller);
|
||||
if (sendrecv_data->size() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
}
|
||||
inline void Broadcast(std::string *sendrecv_data, int root,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root, _file, _line, _caller);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
}
|
||||
|
||||
// perform inplace Allreduce
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
inline void InvokeLambda_(void *fun) {
|
||||
(*static_cast<std::function<void()>*>(fun))();
|
||||
}
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
|
||||
// Performs inplace Allgather
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf,
|
||||
size_t totalSize,
|
||||
size_t beginIndex,
|
||||
size_t sizeNodeSlice,
|
||||
size_t sizePrevSlice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
|
||||
(beginIndex + sizeNodeSlice) * sizeof(DType),
|
||||
sizePrevSlice * sizeof(DType), _file, _line, _caller);
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// print message to the tracker
|
||||
inline void TrackerPrint(const std::string &msg) {
|
||||
engine::GetEngine()->TrackerPrint(msg);
|
||||
}
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
inline void TrackerPrintf(const char *fmt, ...) {
|
||||
const int kPrintBuffer = 1 << 10;
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
msg.resize(strlen(msg.c_str()));
|
||||
TrackerPrint(msg);
|
||||
}
|
||||
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
|
||||
}
|
||||
// checkpoint the model, meaning we finished a stage of execution
|
||||
inline void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) {
|
||||
engine::GetEngine()->CheckPoint(global_model, local_model);
|
||||
}
|
||||
// lazy checkpoint the model, only remember the pointer to global_model
|
||||
inline void LazyCheckPoint(const Serializable *global_model) {
|
||||
engine::GetEngine()->LazyCheckPoint(global_model);
|
||||
}
|
||||
// return the version number of currently stored model
|
||||
inline int VersionNumber(void) {
|
||||
return engine::GetEngine()->VersionNumber();
|
||||
}
|
||||
// ---------------------------------
|
||||
// Code to handle customized Reduce
|
||||
// ---------------------------------
|
||||
// function to perform reduction for Reducer
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||
const size_t kUnit = sizeof(DType);
|
||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||
char *pdst = reinterpret_cast<char*>(dst_);
|
||||
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
DType tdst, tsrc;
|
||||
// use memcpy to avoid alignment issue
|
||||
std::memcpy(&tdst, pdst + (i * kUnit), sizeof(tdst));
|
||||
std::memcpy(&tsrc, psrc + (i * kUnit), sizeof(tsrc));
|
||||
freduce(tdst, tsrc);
|
||||
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
||||
}
|
||||
}
|
||||
// function to perform reduction for Reducer
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||
inline void ReducerAlign_(const void *src_, void *dst_,
|
||||
int len_, const MPI::Datatype &dtype) {
|
||||
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
||||
DType *pdst = reinterpret_cast<DType*>(dst_);
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
freduce(pdst[i], psrc[i]);
|
||||
}
|
||||
}
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||
inline Reducer<DType, freduce>::Reducer(void) {
|
||||
// it is safe to directly use handle for aligned data types
|
||||
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
||||
this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType));
|
||||
} else {
|
||||
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
||||
}
|
||||
}
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun,
|
||||
prepare_arg, _file, _line, _caller);
|
||||
}
|
||||
// function to perform reduction for SerializeReducer
|
||||
template<typename DType>
|
||||
inline void SerializeReducerFunc_(const void *src_, void *dst_,
|
||||
int len_, const MPI::Datatype &dtype) {
|
||||
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||
// temp space
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
DType tsrc, tdst;
|
||||
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes); // NOLINT(*)
|
||||
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes); // NOLINT(*)
|
||||
tsrc.Load(fsrc);
|
||||
tdst.Load(fdst);
|
||||
// govern const check
|
||||
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
|
||||
fdst.Seek(0);
|
||||
tdst.Save(fdst);
|
||||
}
|
||||
}
|
||||
template<typename DType>
|
||||
inline SerializeReducer<DType>::SerializeReducer(void) {
|
||||
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
|
||||
}
|
||||
// closure to call Allreduce
|
||||
template<typename DType>
|
||||
struct SerializeReduceClosure {
|
||||
DType *sendrecvobj;
|
||||
size_t max_nbyte, count;
|
||||
void (*prepare_fun)(void *arg);
|
||||
void *prepare_arg;
|
||||
std::string *p_buffer;
|
||||
// invoke the closure
|
||||
inline void Run(void) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Save(fs);
|
||||
}
|
||||
}
|
||||
inline static void Invoke(void *c) {
|
||||
static_cast<SerializeReduceClosure<DType>*>(c)->Run();
|
||||
}
|
||||
};
|
||||
template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
buffer_.resize(max_nbyte * count);
|
||||
// setup closure
|
||||
SerializeReduceClosure<DType> c;
|
||||
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
|
||||
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
||||
// invoke here
|
||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
||||
SerializeReduceClosure<DType>::Invoke, &c,
|
||||
_file, _line, _caller);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Load(fs);
|
||||
}
|
||||
}
|
||||
|
||||
#if DMLC_USE_CXX11
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)g
|
||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbytes, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
#endif // DMLC_USE_CXX11
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_RABIT_INL_H_
|
||||
536
rabit/include/rabit/internal/socket.h
Normal file
536
rabit/include/rabit/internal/socket.h
Normal file
@@ -0,0 +1,536 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2019 by Contributors
|
||||
* \file socket.h
|
||||
* \brief this file aims to provide a wrapper of sockets
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_SOCKET_H_
|
||||
#define RABIT_INTERNAL_SOCKET_H_
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "Ws2_32.lib")
|
||||
#endif // _MSC_VER
|
||||
#else
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <errno.h>
|
||||
#include <unistd.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/ioctl.h>
|
||||
#endif // defined(_WIN32)
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "utils.h"
|
||||
|
||||
#if defined(_WIN32) || defined(__MINGW32__)
|
||||
typedef int ssize_t;
|
||||
#endif // defined(_WIN32) || defined(__MINGW32__)
|
||||
|
||||
#if defined(_WIN32)
|
||||
typedef int sock_size_t;
|
||||
|
||||
static inline int poll(struct pollfd *pfd, int nfds,
|
||||
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
|
||||
#else
|
||||
#include <sys/poll.h>
|
||||
typedef int SOCKET;
|
||||
typedef size_t sock_size_t;
|
||||
const int INVALID_SOCKET = -1;
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief data structure for network address */
|
||||
struct SockAddr {
|
||||
sockaddr_in addr;
|
||||
// constructor
|
||||
SockAddr(void) {}
|
||||
SockAddr(const char *url, int port) {
|
||||
this->Set(url, port);
|
||||
}
|
||||
inline static std::string GetHostName(void) {
|
||||
std::string buf; buf.resize(256);
|
||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||
return std::string(buf.c_str());
|
||||
}
|
||||
/*!
|
||||
* \brief set the address
|
||||
* \param url the url of the address
|
||||
* \param port the port of address
|
||||
*/
|
||||
inline void Set(const char *host, int port) {
|
||||
addrinfo hints;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_protocol = SOCK_STREAM;
|
||||
addrinfo *res = NULL;
|
||||
int sig = getaddrinfo(host, NULL, &hints, &res);
|
||||
Check(sig == 0 && res != NULL, "cannot obtain address of %s", host);
|
||||
Check(res->ai_family == AF_INET, "Does not support IPv6");
|
||||
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||
addr.sin_port = htons(port);
|
||||
freeaddrinfo(res);
|
||||
}
|
||||
/*! \brief return port of the address*/
|
||||
inline int port(void) const {
|
||||
return ntohs(addr.sin_port);
|
||||
}
|
||||
/*! \return a string representation of the address */
|
||||
inline std::string AddrStr(void) const {
|
||||
std::string buf; buf.resize(256);
|
||||
#ifdef _WIN32
|
||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#else
|
||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#endif // _WIN32
|
||||
Assert(s != NULL, "cannot decode address");
|
||||
return std::string(s);
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief base class containing common operations of TCP and UDP sockets
|
||||
*/
|
||||
class Socket {
|
||||
public:
|
||||
/*! \brief the file descriptor of socket */
|
||||
SOCKET sockfd;
|
||||
// default conversion to int
|
||||
inline operator SOCKET() const {
|
||||
return sockfd;
|
||||
}
|
||||
/*!
|
||||
* \return last error of socket operation
|
||||
*/
|
||||
inline static int GetLastError(void) {
|
||||
#ifdef _WIN32
|
||||
return WSAGetLastError();
|
||||
#else
|
||||
return errno;
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*! \return whether last error was would block */
|
||||
inline static bool LastErrorWouldBlock(void) {
|
||||
int errsv = GetLastError();
|
||||
#ifdef _WIN32
|
||||
return errsv == WSAEWOULDBLOCK;
|
||||
#else
|
||||
return errsv == EAGAIN || errsv == EWOULDBLOCK;
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief start up the socket module
|
||||
* call this before using the sockets
|
||||
*/
|
||||
inline static void Startup(void) {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsa_data;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||
WSACleanup();
|
||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||
}
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief shutdown the socket module after use, all sockets need to be closed
|
||||
*/
|
||||
inline static void Finalize(void) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief set this socket to use non-blocking mode
|
||||
* \param non_block whether set it to be non-block, if it is false
|
||||
* it will set it back to block mode
|
||||
*/
|
||||
inline void SetNonBlock(bool non_block) {
|
||||
#ifdef _WIN32
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
Socket::Error("SetNonBlock");
|
||||
}
|
||||
#else
|
||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||
if (flag == -1) {
|
||||
Socket::Error("SetNonBlock-1");
|
||||
}
|
||||
if (non_block) {
|
||||
flag |= O_NONBLOCK;
|
||||
} else {
|
||||
flag &= ~O_NONBLOCK;
|
||||
}
|
||||
if (fcntl(sockfd, F_SETFL, flag) == -1) {
|
||||
Socket::Error("SetNonBlock-2");
|
||||
}
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief bind the socket to an address
|
||||
* \param addr
|
||||
*/
|
||||
inline void Bind(const SockAddr &addr) {
|
||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == -1) {
|
||||
Socket::Error("Bind");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief try bind the socket to host, from start_port to end_port
|
||||
* \param start_port starting port number to try
|
||||
* \param end_port ending port number to try
|
||||
* \return the port successfully bind to, return -1 if failed to bind any port
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO(tqchen) add prefix check
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0) {
|
||||
return port;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
if (WSAGetLastError() != WSAEADDRINUSE) {
|
||||
Socket::Error("TryBindHost");
|
||||
}
|
||||
#else
|
||||
if (errno != EADDRINUSE) {
|
||||
Socket::Error("TryBindHost");
|
||||
}
|
||||
#endif // defined(_WIN32)
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
/*! \brief get last error code if any */
|
||||
inline int GetSockError(void) const {
|
||||
int error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||
reinterpret_cast<char*>(&error), &len) != 0) {
|
||||
Error("GetSockError");
|
||||
}
|
||||
return error;
|
||||
}
|
||||
/*! \brief check if anything bad happens */
|
||||
inline bool BadSocket(void) const {
|
||||
if (IsClosed()) return true;
|
||||
int err = GetSockError();
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
return false;
|
||||
}
|
||||
/*! \brief check if socket is already closed */
|
||||
inline bool IsClosed(void) const {
|
||||
return sockfd == INVALID_SOCKET;
|
||||
}
|
||||
/*! \brief close the socket */
|
||||
inline void Close(void) {
|
||||
if (sockfd != INVALID_SOCKET) {
|
||||
#ifdef _WIN32
|
||||
closesocket(sockfd);
|
||||
#else
|
||||
close(sockfd);
|
||||
#endif
|
||||
sockfd = INVALID_SOCKET;
|
||||
} else {
|
||||
Error("Socket::Close double close the socket or close without create");
|
||||
}
|
||||
}
|
||||
// report an socket error
|
||||
inline static void Error(const char *msg) {
|
||||
int errsv = GetLastError();
|
||||
#ifdef _WIN32
|
||||
utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
|
||||
#else
|
||||
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
|
||||
#endif
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief a wrapper of TCP socket that hopefully be cross platform
|
||||
*/
|
||||
class TCPSocket : public Socket{
|
||||
public:
|
||||
// constructor
|
||||
TCPSocket(void) : Socket(INVALID_SOCKET) {
|
||||
}
|
||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
||||
}
|
||||
/*!
|
||||
* \brief enable/disable TCP keepalive
|
||||
* \param keepalive whether to set the keep alive option on
|
||||
*/
|
||||
void SetKeepAlive(bool keepalive) {
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
|
||||
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
}
|
||||
inline void SetLinger(int timeout = 0) {
|
||||
struct linger sl;
|
||||
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
|
||||
sl.l_linger = timeout; /* timeout interval in seconds */
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
|
||||
Socket::Error("SO_LINGER");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
* \param af domain
|
||||
*/
|
||||
inline void Create(int af = PF_INET) {
|
||||
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd == INVALID_SOCKET) {
|
||||
Socket::Error("Create");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform listen of the socket
|
||||
* \param backlog backlog parameter
|
||||
*/
|
||||
inline void Listen(int backlog = 16) {
|
||||
listen(sockfd, backlog);
|
||||
}
|
||||
/*! \brief get a new connection */
|
||||
TCPSocket Accept(void) {
|
||||
SOCKET newfd = accept(sockfd, NULL, NULL);
|
||||
if (newfd == INVALID_SOCKET) {
|
||||
Socket::Error("Accept");
|
||||
}
|
||||
return TCPSocket(newfd);
|
||||
}
|
||||
/*!
|
||||
* \brief decide whether the socket is at OOB mark
|
||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||
*/
|
||||
inline int AtMark(void) const {
|
||||
#ifdef _WIN32
|
||||
unsigned long atmark; // NOLINT(*)
|
||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||
#else
|
||||
int atmark;
|
||||
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
|
||||
#endif // _WIN32
|
||||
return static_cast<int>(atmark);
|
||||
}
|
||||
/*!
|
||||
* \brief connect to an address
|
||||
* \param addr the address to connect to
|
||||
* \return whether connect is successful
|
||||
*/
|
||||
inline bool Connect(const SockAddr &addr) {
|
||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0;
|
||||
}
|
||||
/*!
|
||||
* \brief send data using the socket
|
||||
* \param buf the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \param flags extra flags
|
||||
* \return size of data actually sent
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||
}
|
||||
/*!
|
||||
* \brief receive data using the socket
|
||||
* \param buf_ the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \param flags extra flags
|
||||
* \return size of data actually received
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
}
|
||||
/*!
|
||||
* \brief peform block write that will attempt to send all data out
|
||||
* can still return smaller than request when error occurs
|
||||
* \param buf the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t SendAll(const void *buf_, size_t len) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
||||
if (ret == -1) {
|
||||
if (LastErrorWouldBlock()) return ndone;
|
||||
Socket::Error("SendAll");
|
||||
}
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
* \brief peforma block read that will attempt to read all data
|
||||
* can still return smaller than request when error occurs
|
||||
* \param buf_ the buffer pointer
|
||||
* \param len length of data to recv
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t RecvAll(void *buf_, size_t len) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(sockfd, buf,
|
||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
if (ret == -1) {
|
||||
if (LastErrorWouldBlock()) return ndone;
|
||||
Socket::Error("RecvAll");
|
||||
}
|
||||
if (ret == 0) return ndone;
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
* \brief send a string over network
|
||||
* \param str the string to be sent
|
||||
*/
|
||||
inline void SendStr(const std::string &str) {
|
||||
int len = static_cast<int>(str.length());
|
||||
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
|
||||
"error during send SendStr");
|
||||
if (len != 0) {
|
||||
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
|
||||
"error during send SendStr");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief recv a string from network
|
||||
* \param out_str the string to receive
|
||||
*/
|
||||
inline void RecvStr(std::string *out_str) {
|
||||
int len;
|
||||
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
|
||||
"error during send RecvStr");
|
||||
out_str->resize(len);
|
||||
if (len != 0) {
|
||||
utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
|
||||
"error during send SendStr");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief helper data structure to perform poll */
|
||||
struct PollHelper {
|
||||
public:
|
||||
/*!
|
||||
* \brief add file descriptor to watch for read
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchRead(SOCKET fd) {
|
||||
auto& pfd = fds[fd];
|
||||
pfd.fd = fd;
|
||||
pfd.events |= POLLIN;
|
||||
}
|
||||
/*!
|
||||
* \brief add file descriptor to watch for write
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchWrite(SOCKET fd) {
|
||||
auto& pfd = fds[fd];
|
||||
pfd.fd = fd;
|
||||
pfd.events |= POLLOUT;
|
||||
}
|
||||
/*!
|
||||
* \brief add file descriptor to watch for exception
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchException(SOCKET fd) {
|
||||
auto& pfd = fds[fd];
|
||||
pfd.fd = fd;
|
||||
pfd.events |= POLLPRI;
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for read
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckRead(SOCKET fd) const {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for write
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckWrite(SOCKET fd) const {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor has any exception
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckExcept(SOCKET fd) const {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
|
||||
}
|
||||
/*!
|
||||
* \brief wait for exception event on a single descriptor
|
||||
* \param fd the file descriptor to wait the event for
|
||||
* \param timeout the timeout counter, can be negative, which means wait until the event happen
|
||||
* \return 1 if success, 0 if timeout, and -1 if error occurs
|
||||
*/
|
||||
inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
|
||||
pollfd pfd;
|
||||
pfd.fd = fd;
|
||||
pfd.events = POLLPRI;
|
||||
return poll(&pfd, 1, timeout);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief peform poll on the set defined, read, write, exception
|
||||
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
|
||||
* \return
|
||||
*/
|
||||
inline void Poll(long timeout = -1) { // NOLINT(*)
|
||||
std::vector<pollfd> fdset;
|
||||
fdset.reserve(fds.size());
|
||||
for (auto kv : fds) {
|
||||
fdset.push_back(kv.second);
|
||||
}
|
||||
int ret = poll(fdset.data(), fdset.size(), timeout);
|
||||
if (ret == -1) {
|
||||
Socket::Error("Poll");
|
||||
} else {
|
||||
for (auto& pfd : fdset) {
|
||||
auto revents = pfd.revents & pfd.events;
|
||||
if (!revents) {
|
||||
fds.erase(pfd.fd);
|
||||
} else {
|
||||
fds[pfd.fd].events = revents;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<SOCKET, pollfd> fds;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_SOCKET_H_
|
||||
87
rabit/include/rabit/internal/thread_local.h
Normal file
87
rabit/include/rabit/internal/thread_local.h
Normal file
@@ -0,0 +1,87 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* \file thread_local.h
|
||||
* \brief Common utility for thread local storage.
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_THREAD_LOCAL_H_
|
||||
#define RABIT_INTERNAL_THREAD_LOCAL_H_
|
||||
|
||||
#include "../include/dmlc/base.h"
|
||||
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
#include <mutex>
|
||||
#endif // DMLC_ENABLE_STD_THREAD
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace rabit {
|
||||
|
||||
// macro hanlding for threadlocal variables
|
||||
#ifdef __GNUC__
|
||||
#define MX_TREAD_LOCAL __thread
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
#define MX_TREAD_LOCAL _Thread_local
|
||||
#elif defined(_MSC_VER)
|
||||
#define MX_TREAD_LOCAL __declspec(thread)
|
||||
#endif // __GNUC__
|
||||
|
||||
#ifndef MX_TREAD_LOCAL
|
||||
#message("Warning: Threadlocal is not enabled");
|
||||
#endif // MX_TREAD_LOCAL
|
||||
|
||||
/*!
|
||||
* \brief A threadlocal store to store threadlocal variables.
|
||||
* Will return a thread local singleton of type T
|
||||
* \tparam T the type we like to store
|
||||
*/
|
||||
template<typename T>
|
||||
class ThreadLocalStore {
|
||||
public:
|
||||
/*! \return get a thread local singleton */
|
||||
static T* Get() {
|
||||
static MX_TREAD_LOCAL T* ptr = nullptr;
|
||||
if (ptr == nullptr) {
|
||||
ptr = new T();
|
||||
Singleton()->RegisterDelete(ptr);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief constructor */
|
||||
ThreadLocalStore() {}
|
||||
/*! \brief destructor */
|
||||
~ThreadLocalStore() {
|
||||
for (size_t i = 0; i < data_.size(); ++i) {
|
||||
delete data_[i];
|
||||
}
|
||||
}
|
||||
/*! \return singleton of the store */
|
||||
static ThreadLocalStore<T> *Singleton() {
|
||||
static ThreadLocalStore<T> inst;
|
||||
return &inst;
|
||||
}
|
||||
/*!
|
||||
* \brief register str for internal deletion
|
||||
* \param str the string pointer
|
||||
*/
|
||||
void RegisterDelete(T *str) {
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
data_.push_back(str);
|
||||
lock.unlock();
|
||||
#else
|
||||
data_.push_back(str);
|
||||
#endif // DMLC_ENABLE_STD_THREAD
|
||||
}
|
||||
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
/*! \brief internal mutex */
|
||||
std::mutex mutex_;
|
||||
#endif // DMLC_ENABLE_STD_THREAD
|
||||
/*!\brief internal data */
|
||||
std::vector<T*> data_;
|
||||
};
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_THREAD_LOCAL_H_
|
||||
41
rabit/include/rabit/internal/timer.h
Normal file
41
rabit/include/rabit/internal/timer.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2019 by Contributors
|
||||
* \file timer.h
|
||||
* \brief This file defines the utils for timing
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_TIMER_H_
|
||||
#define RABIT_INTERNAL_TIMER_H_
|
||||
#include <time.h>
|
||||
#ifdef __MACH__
|
||||
#include <mach/clock.h>
|
||||
#include <mach/mach.h>
|
||||
#endif // __MACH__
|
||||
#include "./utils.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*!
|
||||
* \brief return time in seconds, not cross platform, avoid to use this in most places
|
||||
*/
|
||||
inline double GetTime(void) {
|
||||
#ifdef __MACH__
|
||||
clock_serv_t cclock;
|
||||
mach_timespec_t mts;
|
||||
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
|
||||
utils::Check(clock_get_time(cclock, &mts) == 0, "failed to get time");
|
||||
mach_port_deallocate(mach_task_self(), cclock);
|
||||
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
|
||||
#else
|
||||
#if defined(__unix__) || defined(__linux__)
|
||||
timespec ts;
|
||||
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
|
||||
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
|
||||
#else
|
||||
return static_cast<double>(time(NULL));
|
||||
#endif // defined(__unix__) || defined(__linux__)
|
||||
#endif // __MACH__
|
||||
}
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_TIMER_H_
|
||||
219
rabit/include/rabit/internal/utils.h
Normal file
219
rabit/include/rabit/internal/utils.h
Normal file
@@ -0,0 +1,219 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file utils.h
|
||||
* \brief simple utils to support the code
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_UTILS_H_
|
||||
#define RABIT_INTERNAL_UTILS_H_
|
||||
|
||||
#include <rabit/base.h>
|
||||
#include <string.h>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include "dmlc/io.h"
|
||||
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
#include <cstdarg>
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
#define fopen64 std::fopen
|
||||
#endif // !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// NOTE: sprintf_s is not equivalent to snprintf,
|
||||
// they are equivalent when success, which is sufficient for our case
|
||||
#define snprintf sprintf_s
|
||||
#define vsnprintf vsprintf_s
|
||||
|
||||
#else
|
||||
|
||||
#ifdef _FILE_OFFSET_BITS
|
||||
#if _FILE_OFFSET_BITS == 32
|
||||
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
|
||||
#endif // _FILE_OFFSET_BITS == 32
|
||||
#endif // _FILE_OFFSET_BITS
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define off64_t off_t
|
||||
#define fopen64 std::fopen
|
||||
#endif // __APPLE__
|
||||
|
||||
extern "C" {
|
||||
#include <sys/types.h>
|
||||
}
|
||||
#endif // _MSC_VER
|
||||
|
||||
#ifdef _MSC_VER
|
||||
typedef unsigned char uint8_t;
|
||||
typedef unsigned __int16 uint16_t;
|
||||
typedef unsigned __int32 uint32_t;
|
||||
typedef unsigned __int64 uint64_t;
|
||||
typedef __int64 int64_t;
|
||||
#else
|
||||
#include <inttypes.h>
|
||||
#endif // _MSC_VER
|
||||
|
||||
namespace rabit {
|
||||
/*! \brief namespace for helper utils of the project */
|
||||
namespace utils {
|
||||
|
||||
/*! \brief error message buffer length */
|
||||
const int kPrintBuffer = 1 << 12;
|
||||
|
||||
/* \brief Case-insensitive string comparison */
|
||||
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
|
||||
#ifdef _MSC_VER
|
||||
return _stricmp(s1, s2);
|
||||
#else // _MSC_VER
|
||||
return strcasecmp(s1, s2);
|
||||
#endif // _MSC_VER
|
||||
}
|
||||
|
||||
/* \brief parse config string too bool*/
|
||||
inline bool StringToBool(const char* s) {
|
||||
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
|
||||
}
|
||||
|
||||
#ifndef RABIT_CUSTOMIZE_MSG_
|
||||
/*!
|
||||
* \brief handling of Assert error, caused by inappropriate input
|
||||
* \param msg error message
|
||||
*/
|
||||
inline void HandleAssertError(const char *msg) {
|
||||
fprintf(stderr,
|
||||
"AssertError:%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
/*!
|
||||
* \brief handling of Check error, caused by inappropriate input
|
||||
* \param msg error message
|
||||
*/
|
||||
inline void HandleCheckError(const char *msg) {
|
||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
inline void HandlePrint(const char *msg) {
|
||||
printf("%s", msg);
|
||||
}
|
||||
|
||||
inline void HandleLogInfo(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
fprintf(stdout, "%s", msg.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
#else
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
// include declarations, some one must implement this
|
||||
void HandleAssertError(const char *msg);
|
||||
void HandleCheckError(const char *msg);
|
||||
void HandlePrint(const char *msg);
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
#endif // RABIT_CUSTOMIZE_MSG_
|
||||
#ifdef RABIT_STRICT_CXX98_
|
||||
// these function pointers are to be assigned
|
||||
extern "C" void (*Printf)(const char *fmt, ...);
|
||||
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
|
||||
extern "C" void (*Assert)(int exp, const char *fmt, ...);
|
||||
extern "C" void (*Check)(int exp, const char *fmt, ...);
|
||||
extern "C" void (*Error)(const char *fmt, ...);
|
||||
#else
|
||||
/*! \brief printf, prints messages to the console */
|
||||
inline void Printf(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
HandlePrint(msg.c_str());
|
||||
}
|
||||
/*! \brief portable version of snprintf */
|
||||
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
int ret = vsnprintf(buf, size, fmt, args);
|
||||
va_end(args);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*! \brief assert a condition is true, use this to handle debug information */
|
||||
inline void Assert(bool exp, const char *fmt, ...) {
|
||||
if (!exp) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
HandleAssertError(msg.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
/*!\brief same as assert, but this is intended to be used as a message for users */
|
||||
inline void Check(bool exp, const char *fmt, ...) {
|
||||
if (!exp) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
HandleCheckError(msg.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief report error message, same as check */
|
||||
inline void Error(const char *fmt, ...) {
|
||||
{
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
HandleCheckError(msg.c_str());
|
||||
}
|
||||
}
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
/*! \brief replace fopen, report error when the file open fails */
|
||||
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||
std::FILE *fp = fopen64(fname, flag);
|
||||
Check(fp != NULL, "can not open file \"%s\"\n", fname);
|
||||
return fp;
|
||||
}
|
||||
} // namespace utils
|
||||
// easy utils that can be directly accessed in xgboost
|
||||
/*! \brief get the beginning address of a vector */
|
||||
template<typename T>
|
||||
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
|
||||
if (vec.size() == 0) {
|
||||
return NULL;
|
||||
} else {
|
||||
return &vec[0];
|
||||
}
|
||||
}
|
||||
/*! \brief get the beginning address of a vector */
|
||||
template<typename T>
|
||||
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
|
||||
if (vec.size() == 0) {
|
||||
return NULL;
|
||||
} else {
|
||||
return &vec[0];
|
||||
}
|
||||
}
|
||||
inline char* BeginPtr(std::string &str) { // NOLINT(*)
|
||||
if (str.length() == 0) return NULL;
|
||||
return &str[0];
|
||||
}
|
||||
inline const char* BeginPtr(const std::string &str) {
|
||||
if (str.length() == 0) return NULL;
|
||||
return &str[0];
|
||||
}
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_UTILS_H_
|
||||
460
rabit/include/rabit/rabit.h
Normal file
460
rabit/include/rabit/rabit.h
Normal file
@@ -0,0 +1,460 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file rabit.h
|
||||
* \brief This file defines rabit's Allreduce/Broadcast interface
|
||||
* The rabit engine contains the actual implementation
|
||||
* Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant),
|
||||
*
|
||||
* rabit.h and serializable.h is all what the user needs to use the rabit interface
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_RABIT_H_ // NOLINT(*)
|
||||
#define RABIT_RABIT_H_ // NOLINT(*)
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// whether or not use c++11 support
|
||||
#ifndef DMLC_USE_CXX11
|
||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
|
||||
#define DMLC_USE_CXX11 1
|
||||
#else
|
||||
#define DMLC_USE_CXX11 (__cplusplus >= 201103L)
|
||||
#endif // defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
|
||||
#endif // DMLC_USE_CXX11
|
||||
|
||||
// keeps rabit api caller signature
|
||||
#ifndef RABIT_API_CALLER_SIGNATURE
|
||||
#define RABIT_API_CALLER_SIGNATURE
|
||||
|
||||
#if (defined(__GNUC__) && !defined(__clang__))
|
||||
#define _FILE __builtin_FILE()
|
||||
#define _LINE __builtin_LINE()
|
||||
#define _CALLER __builtin_FUNCTION()
|
||||
#else
|
||||
#define _FILE "N/A"
|
||||
#define _LINE -1
|
||||
#define _CALLER "N/A"
|
||||
#endif // (defined(__GNUC__) && !defined(__clang__))
|
||||
|
||||
#endif // RABIT_API_CALLER_SIGNATURE
|
||||
|
||||
// optionally support of lambda functions in C++11, if available
|
||||
#if DMLC_USE_CXX11
|
||||
#include <functional>
|
||||
#endif // C++11
|
||||
// engine definition of rabit, defines internal implementation
|
||||
// to use rabit interface, there is no need to read engine.h
|
||||
// rabit.h and serializable.h are enough to use the interface
|
||||
#include "./internal/engine.h"
|
||||
|
||||
/*! \brief rabit namespace */
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief defines stream used in rabit
|
||||
* see definition of Stream in dmlc/io.h
|
||||
*/
|
||||
typedef dmlc::Stream Stream;
|
||||
/*!
|
||||
* \brief defines serializable objects used in rabit
|
||||
* see definition of Serializable in dmlc/io.h
|
||||
*/
|
||||
typedef dmlc::Serializable Serializable;
|
||||
|
||||
/*!
|
||||
* \brief reduction operators namespace
|
||||
*/
|
||||
namespace op {
|
||||
/*!
|
||||
* \class rabit::op::Max
|
||||
* \brief maximum reduction operator
|
||||
*/
|
||||
struct Max;
|
||||
/*!
|
||||
* \class rabit::op::Min
|
||||
* \brief minimum reduction operator
|
||||
*/
|
||||
struct Min;
|
||||
/*!
|
||||
* \class rabit::op::Sum
|
||||
* \brief sum reduction operator
|
||||
*/
|
||||
struct Sum;
|
||||
/*!
|
||||
* \class rabit::op::BitOR
|
||||
* \brief bitwise OR reduction operator
|
||||
*/
|
||||
struct BitOR;
|
||||
} // namespace op
|
||||
/*!
|
||||
* \brief initializes rabit, call this once at the beginning of your program
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if initialized successfully, otherwise false
|
||||
*/
|
||||
inline bool Init(int argc, char *argv[]);
|
||||
/*!
|
||||
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
||||
* \return true if finalized successfully, otherwise false
|
||||
*/
|
||||
inline bool Finalize();
|
||||
/*! \brief gets rank of the current process
|
||||
* \return rank number of worker*/
|
||||
inline int GetRank();
|
||||
/*! \brief gets total number of processes
|
||||
* \return total world size*/
|
||||
inline int GetWorldSize();
|
||||
/*! \brief whether rabit env is in distributed mode
|
||||
* \return is distributed*/
|
||||
inline bool IsDistributed();
|
||||
|
||||
/*! \brief gets processor's name
|
||||
* \return processor name*/
|
||||
inline std::string GetProcessorName();
|
||||
/*!
|
||||
* \brief prints the msg to the tracker,
|
||||
* this function can be used to communicate progress information to
|
||||
* the user who monitors the tracker
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
inline void TrackerPrint(const std::string &msg);
|
||||
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief prints the msg to the tracker, this function may not be available
|
||||
* in very strict c++98 compilers, though it usually is.
|
||||
* this function can be used to communicate progress information to
|
||||
* the user who monitors the tracker
|
||||
* \param fmt the format string
|
||||
*/
|
||||
inline void TrackerPrintf(const char *fmt, ...);
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief broadcasts a memory region to every node from the root
|
||||
*
|
||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||
* \param size the data size
|
||||
* \param root the process root
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
/*!
|
||||
* \brief broadcasts an std::vector<DType> to every node from root
|
||||
* \param sendrecv_data the pointer to send/receive vector,
|
||||
* for the receiver, the vector does not need to be pre-allocated
|
||||
* \param root the process root
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \tparam DType the data type stored in the vector, has to be a simple data type
|
||||
* that can be directly transmitted by sending the sizeof(DType)
|
||||
*/
|
||||
template<typename DType>
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief broadcasts a std::string to every node from the root
|
||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||
* for the receiver, the vector does not need to be pre-allocated
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \param root the process root
|
||||
*/
|
||||
inline void Broadcast(std::string *sendrecv_data, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief performs in-place Allreduce on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
*
|
||||
* Example Usage: the following code does an Allreduce and outputs the sum as the result
|
||||
* \code{.cpp}
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size());
|
||||
* ...
|
||||
* \endcode
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType data type
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *) = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
* \brief performs in-place Allreduce, on sendrecvbuf
|
||||
* with a prepare function specified by a lambda function
|
||||
*
|
||||
* Example Usage:
|
||||
* \code{.cpp}
|
||||
* // the following code does an Allreduce and outputs the sum as the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
|
||||
* for (int i = 0; i < 10; ++i) {
|
||||
* data[i] = i;
|
||||
* }
|
||||
* });
|
||||
* ...
|
||||
* \endcode
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
|
||||
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType data type
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#endif // C++11
|
||||
/*!
|
||||
* \brief loads the latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that the global_model
|
||||
* is the same in every node
|
||||
* \param local_model pointer to the local model that is specific to the current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of the check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, users should do the necessary initialization by themselves
|
||||
*
|
||||
* \code{.cpp}
|
||||
* // Example usage code of LoadCheckPoint
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* // do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
* \endcode
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
inline int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoints the model, meaning a stage of execution has finished.
|
||||
* every time we call check point, a version number will be increased by one
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that the global_model
|
||||
* is the same in every node
|
||||
* \param local_model pointer to the local model that is specific to the current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in the CheckPoint function. global_model does not need explicit replication.
|
||||
* So, only CheckPoint with the global_model if possible
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
inline void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met (see detailed explanation).
|
||||
*
|
||||
* This is a "lazy" checkpoint such that only the pointer to the global_model is
|
||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||
* In other words, the global_model model can be changed only between the last call of
|
||||
* Allreduce/Broadcast and LazyCheckPoint, both in the same version
|
||||
*
|
||||
* For example, suppose the calling sequence is:
|
||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint/(or can be CheckPoint)
|
||||
*
|
||||
* Then the user MUST only change the global_model in code3.
|
||||
*
|
||||
* The use of LazyCheckPoint instead of CheckPoint will improve the efficiency of the program.
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that the global_model
|
||||
* is the same in every node
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
inline void LazyCheckPoint(const Serializable *global_model);
|
||||
/*!
|
||||
* \return version number of the current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
inline int VersionNumber();
|
||||
// ----- extensions that allow customized reducer ------
|
||||
// helper class to do customized reduce, user do not need to know the type
|
||||
namespace engine {
|
||||
class ReduceHandle;
|
||||
} // namespace engine
|
||||
/*!
|
||||
* \brief template class to make customized reduce and all reduce easy
|
||||
* Do not use reducer directly in the function you call Finalize,
|
||||
* because the destructor can execute after Finalize
|
||||
* \tparam DType data type that to be reduced
|
||||
* \tparam freduce the customized reduction function
|
||||
* DType must be a struct, with no pointer
|
||||
*/
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||
class Reducer {
|
||||
public:
|
||||
Reducer();
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation
|
||||
* \param sendrecvbuf the in place send-recv buffer
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *) = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||
* \param sendrecvbuf pointer to the array of objects to be reduced
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#endif // DMLC_USE_CXX11
|
||||
|
||||
private:
|
||||
/*! \brief function handle to do reduce */
|
||||
engine::ReduceHandle handle_;
|
||||
};
|
||||
/*!
|
||||
* \brief template class to make customized reduce,
|
||||
* this class defines complex reducer handles all the data structure that can be
|
||||
* serialized/deserialized into fixed size buffer
|
||||
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
|
||||
*
|
||||
* \tparam DType data type that to be reduced, DType must contain the following functions:
|
||||
* \tparam freduce the customized reduction function
|
||||
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
|
||||
*/
|
||||
template<typename DType>
|
||||
class SerializeReducer {
|
||||
public:
|
||||
SerializeReducer();
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation
|
||||
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||
* \param max_nbyte maximum amount of memory needed to serialize each object
|
||||
* this includes budget limit for intermediate and final result
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then the prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *) = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||
* \param max_nbyte maximum amount of memory needed to serialize each object
|
||||
* this includes budget limit for intermediate and final result
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
std::function<void()> prepare_fun,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
#endif // DMLC_USE_CXX11
|
||||
|
||||
private:
|
||||
/*! \brief function handle to do reduce */
|
||||
engine::ReduceHandle handle_;
|
||||
/*! \brief temporal buffer used to do reduce*/
|
||||
std::string buffer_;
|
||||
};
|
||||
} // namespace rabit
|
||||
// implementation of template functions
|
||||
#include "./internal/rabit-inl.h"
|
||||
#endif // RABIT_RABIT_H_ // NOLINT(*)
|
||||
26
rabit/include/rabit/serializable.h
Normal file
26
rabit/include/rabit/serializable.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file serializable.h
|
||||
* \brief defines serializable interface of rabit
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_SERIALIZABLE_H_
|
||||
#define RABIT_SERIALIZABLE_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "rabit/internal/utils.h"
|
||||
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief defines stream used in rabit
|
||||
* see definition of Stream in dmlc/io.h
|
||||
*/
|
||||
typedef dmlc::Stream Stream;
|
||||
/*!
|
||||
* \brief defines serializable objects used in rabit
|
||||
* see definition of Serializable in dmlc/io.h
|
||||
*/
|
||||
typedef dmlc::Serializable Serializable;
|
||||
|
||||
} // namespace rabit
|
||||
#endif // RABIT_SERIALIZABLE_H_
|
||||
Reference in New Issue
Block a user