238 lines
8.2 KiB
C++

/*!
* Copyright (c) 2014 by Contributors
* \file rabit.h
* \brief This file defines rabit's Allreduce/Broadcast interface
* The rabit engine contains the actual implementation
* Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant).
*
* rabit.h and serializable.h are all the user needs to use the rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_RABIT_H_ // NOLINT(*)
#define RABIT_RABIT_H_ // NOLINT(*)
#include <string>
#include <vector>
#include <functional>
// engine definition of rabit, defines internal implementation
// to use rabit interface, there is no need to read engine.h
// rabit.h and serializable.h are enough to use the interface
#include "./internal/engine.h"
/*! \brief rabit namespace */
namespace rabit {
/*!
* \brief stream type used in rabit; this is an alias of dmlc::Stream,
* see the definition of Stream in dmlc/io.h
*/
using Stream = dmlc::Stream;
/*!
* \brief serializable object type used in rabit; this is an alias of dmlc::Serializable,
* see the definition of Serializable in dmlc/io.h
*/
using Serializable = dmlc::Serializable;
/*!
* \brief reduction operators namespace
*
* The structs below are only forward-declared in this header. They are meant to be
* passed as the OP template argument of Allreduce, e.g. Allreduce<op::Sum>(...).
*/
namespace op {
/*!
* \class rabit::op::Max
* \brief maximum reduction operator
*/
struct Max;
/*!
* \class rabit::op::Min
* \brief minimum reduction operator
*/
struct Min;
/*!
* \class rabit::op::Sum
* \brief sum reduction operator
*/
struct Sum;
/*!
* \class rabit::op::BitAND
* \brief bitwise AND reduction operator
*/
struct BitAND;
/*!
* \class rabit::op::BitOR
* \brief bitwise OR reduction operator
*/
struct BitOR;
/*!
* \class rabit::op::BitXOR
* \brief bitwise XOR reduction operator
*/
struct BitXOR;
} // namespace op
/*!
* \brief initializes rabit, call this once at the beginning of your program
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if initialized successfully, otherwise false
* \sa Finalize
*/
inline bool Init(int argc, char *argv[]);
/*!
* \brief finalizes the rabit engine, call this function after you have finished all the jobs
* \return true if finalized successfully, otherwise false
* \sa Init
*/
inline bool Finalize();
/*!
* \brief gets the rank of the current process
* \return rank number of the worker
*/
inline int GetRank();
/*!
* \brief gets the total number of processes
* \return total world size
*/
inline int GetWorldSize();
/*!
* \brief checks whether the rabit environment is in distributed mode
* \return true if rabit is in distributed mode, otherwise false
*/
inline bool IsDistributed();
/*!
* \brief gets the processor's name
* \return the processor name
*/
inline std::string GetProcessorName();
/*!
* \brief prints the message to the tracker;
* this function can be used to communicate progress information to
* the user who monitors the tracker
* \param msg the message to be printed
*/
inline void TrackerPrint(const std::string &msg);
// TrackerPrintf is only declared when RABIT_STRICT_CXX98_ is not defined.
#ifndef RABIT_STRICT_CXX98_
/*!
* \brief prints a formatted message to the tracker; this function may not be available
* in very strict c++98 compilers, though it usually is.
* This function can be used to communicate progress information to
* the user who monitors the tracker
* \param fmt the format string
* \param ... the values referenced by the format string
* (NOTE(review): presumably printf-style formatting; confirm against the implementation)
* \sa TrackerPrint
*/
inline void TrackerPrintf(const char *fmt, ...);
#endif // RABIT_STRICT_CXX98_
/*!
* \brief broadcasts a memory region to every node from the root
*
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to the send/receive buffer
* \param size the data size, in bytes as the sizeof(a) in the example indicates
* \param root the process root
*/
inline void Broadcast(void *sendrecv_data, size_t size, int root);
/*!
* \brief broadcasts a std::vector<DType> to every node from the root
* \param sendrecv_data the pointer to the send/receive vector;
* for the receiver, the vector does not need to be pre-allocated
* \param root the process root
* \tparam DType the data type stored in the vector, has to be a simple data type
* that can be directly transmitted by sending the sizeof(DType)
*/
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
/*!
* \brief broadcasts a std::string to every node from the root
* \param sendrecv_data the pointer to the send/receive string;
* for the receiver, the string does not need to be pre-allocated
* \param root the process root
*/
inline void Broadcast(std::string *sendrecv_data, int root);
/*!
* \brief performs in-place Allreduce on sendrecvbuf
* this function is NOT thread-safe
*
* Example Usage: the following code does an Allreduce and outputs the sum as the result
* \code{.cpp}
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size());
* ...
* \endcode
*
* \param sendrecvbuf buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param prepare_fun Lazy preprocessing function; if it is not nullptr, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed to the lazy preprocessing function
* \tparam OP see namespace op, reduce operator
* \tparam DType data type
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr);
/*!
* \brief Allgather function; each node has a segment of data in the ring of sendrecvbuf_,
* the data provided by the current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end.
* After the call of Allgather, sendrecvbuf_ contains all the contents including all segments.
* Uses a ring based algorithm
*
* NOTE(review): the unit of total_size, the slice offsets and size_prev_slice
* (elements of DType vs. bytes) is not stated here; confirm against the implementation.
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \tparam DType data type of the buffer that sendrecvbuf_ points to
*/
template<typename DType>
inline void Allgather(DType *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice);
// C++11 support for lambda prepare function;
// this overload is only declared when DMLC_USE_CXX11 evaluates to non-zero
#if DMLC_USE_CXX11
/*!
* \brief performs in-place Allreduce on sendrecvbuf,
* with a prepare function specified by a lambda function
*
* Example Usage:
* \code{.cpp}
* // the following code does an Allreduce and outputs the sum as the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
*   for (int i = 0; i < 10; ++i) {
*     data[i] = i;
*   }
* });
* ...
* \endcode
* \param sendrecvbuf buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \tparam OP see namespace op, reduce operator
* \tparam DType data type
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief deprecated, planned for removal after checkpointing from the JVM package is removed.
* \return NOTE(review): meaning of the returned int is not documented here;
* presumably the version number of the loaded model, confirm against the implementation
* \sa CheckPoint, VersionNumber
*/
inline int LoadCheckPoint();
/*!
* \brief deprecated, planned for removal after checkpointing from the JVM package is removed.
* \sa LoadCheckPoint, VersionNumber
*/
inline void CheckPoint();
/*!
* \brief gets the version number of the current stored model
* \return version number of the current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber();
} // namespace rabit
// implementation of template functions
#include "./internal/rabit-inl.h"
#endif // RABIT_RABIT_H_ // NOLINT(*)