[RABIT] fix rabit in local mode

This commit is contained in:
tqchen 2016-01-12 21:34:26 -08:00
parent 05b958c178
commit 112d866dc9
2 changed files with 14 additions and 12 deletions

View File

@ -21,7 +21,7 @@ namespace engine {
/*! \brief interface of core Allreduce engine */ /*! \brief interface of core Allreduce engine */
class IEngine { class IEngine {
public: public:
/*! /*!
* \brief Preprocessing function, that is called before AllReduce, * \brief Preprocessing function, that is called before AllReduce,
* used to prepare the data used by AllReduce * used to prepare the data used by AllReduce
* \param arg additional possible argument used to invoke the preprocessor * \param arg additional possible argument used to invoke the preprocessor
@ -41,6 +41,8 @@ class IEngine {
typedef void (ReduceFunction) (const void *src, typedef void (ReduceFunction) (const void *src,
void *dst, int count, void *dst, int count,
const MPI::Datatype &dtype); const MPI::Datatype &dtype);
/*! \brief virtual destructor */
virtual ~IEngine() {}
/*! /*!
* \brief performs in-place Allreduce, on sendrecvbuf * \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
@ -83,14 +85,14 @@ class IEngine {
* \return the version number of the model loaded * \return the version number of the model loaded
* if returned version == 0, this means no model has been CheckPointed * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, users should do necessary initialization by themselves * the p_model is not touched, users should do necessary initialization by themselves
* *
* Common usage example: * Common usage example:
* int iter = rabit::LoadCheckPoint(&model); * int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters(); * if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) { * for (i = iter; i < max_iter; ++i) {
* do many things, including allreduce * do many things, including allreduce
* rabit::CheckPoint(model); * rabit::CheckPoint(model);
* } * }
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
@ -99,7 +101,7 @@ class IEngine {
/*! /*!
* \brief checkpoints the model, meaning a stage of execution was finished * \brief checkpoints the model, meaning a stage of execution was finished
* every time we call CheckPoint, the version number increases by one * every time we call CheckPoint, the version number increases by one
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that the global_model * when calling this function, the caller needs to guarantee that the global_model
* is the same in every node * is the same in every node
@ -117,16 +119,16 @@ class IEngine {
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation). * when certain condition is met (see detailed explanation).
* *
* This is a "lazy" checkpoint such that only the pointer to global_model is * This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that: * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes. * The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
* In other words, global_model can be changed only between the last call of * In other words, global_model can be changed only between the last call of
* Allreduce/Broadcast and LazyCheckPoint in the current version * Allreduce/Broadcast and LazyCheckPoint in the current version
* *
* For example, suppose the calling sequence is: * For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
* *
* If the user can only change global_model in code3, then LazyCheckPoint can be used to * If the user can only change global_model in code3, then LazyCheckPoint can be used to
* improve the efficiency of the program. * improve the efficiency of the program.
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
@ -189,14 +191,14 @@ enum DataType {
}; };
} // namespace mpi } // namespace mpi
/*! /*!
* \brief perform in-place Allreduce, on sendrecvbuf * \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI * this is an internal function used by rabit to be able to compile with MPI
* do not use this function directly * do not use this function directly
* \param sendrecvbuf buffer for both sending and receiving data * \param sendrecvbuf buffer for both sending and receiving data
* \param type_nbytes the number of bytes the type has * \param type_nbytes the number of bytes the type has
* \param count number of elements to be reduced * \param count number of elements to be reduced
* \param reducer reduce function * \param reducer reduce function
* \param dtype the data type * \param dtype the data type
* \param op the reduce operator type * \param op the reduce operator type
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
@ -229,7 +231,7 @@ class ReduceHandle {
*/ */
void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes); void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
/*! /*!
* \brief customized in-place all reduce operation * \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer * \param sendrecvbuf the in place send-recv buffer
* \param type_n4bytes size of the type, in units of 4 bytes * \param type_n4bytes size of the type, in units of 4 bytes
* \param count number of elements to send * \param count number of elements to send

View File

@ -79,7 +79,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
PreprocFunction prepare_fun, PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg) {
// skip action in single node // skip action in single node
if (world_size == 1) { if (world_size == 1 || world_size == -1) {
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != NULL) prepare_fun(prepare_arg);
return; return;
} }