refactor: librabit

This commit is contained in:
tqchen
2016-02-27 10:14:26 -08:00
parent 73b6e9bbd0
commit 7479791f6a
33 changed files with 412 additions and 801 deletions

View File

@@ -1,7 +0,0 @@
Library Header Files
====
* This folder contains all the headers needed to use the library
* To use it, add the "include" folder to the search path of the compiler
* User only needs to know [rabit.h](rabit.h) and [rabit_serializable.h](rabit_serializable.h) in order to use the library
* Folder [rabit](rabit) contains headers for internal engine and template's implementation
* Not all .h files in the project are in the "include" folder, .h files that are internally used by the library remain at [src](../src)

135
include/rabit/c_api.h Normal file
View File

@@ -0,0 +1,135 @@
/*!
* Copyright by Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief a C style API of rabit.
*/
#ifndef RABIT_C_API_H_
#define RABIT_C_API_H_
// size_t is used by the declarations in this header (RabitAllreduce's count),
// so pull it in here instead of relying on the includer.
#include <stddef.h>
// RABIT_EXTERN_C gives the API C linkage when the header is compiled as C++.
// A plain C compiler has no extern "C", so the macro must expand to nothing
// there; leaving it undefined would break every RABIT_DLL declaration.
#ifdef __cplusplus
#define RABIT_EXTERN_C extern "C"
#else
#define RABIT_EXTERN_C
#endif
// RABIT_DLL prefixes every function exported by the C API;
// on Windows it additionally exports the symbol from the DLL.
#if defined(_MSC_VER) || defined(_WIN32)
#define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport)
#else
#define RABIT_DLL RABIT_EXTERN_C
#endif
// manually define unsigned long
typedef unsigned long rbt_ulong;  // NOLINT(*)
/*!
* \brief initialize the rabit module,
* call this once before using anything
* The additional arguments are not necessary.
* Usually rabit will detect settings
* from environment variables.
* \param argc number of arguments in argv
* \param argv the array of input arguments
*/
RABIT_DLL void RabitInit(int argc, char *argv[]);
/*!
* \brief finalize the rabit engine,
* call this function after you have finished all jobs.
*/
RABIT_DLL void RabitFinalize(void);
/*! \brief get rank of current process */
RABIT_DLL int RabitGetRank(void);
/*! \brief get total number of processes */
RABIT_DLL int RabitGetWorldSize(void);
/*!
* \brief print the msg to the tracker,
* this function can be used to communicate progress information to
* the user who monitors the tracker
* \param msg the message to be printed
* NOTE(review): no length is passed, so msg is presumably a
* null-terminated C string -- confirm against the implementation
*/
RABIT_DLL void RabitTrackerPrint(const char *msg);
/*!
* \brief get name of processor
* \param out_name buffer that holds the output string
* \param out_len holds the length of the output string
* \param max_len maximum buffer length of the input buffer
* NOTE(review): max_len is presumably the capacity of out_name;
* behavior when the name does not fit is not visible here -- confirm
*/
RABIT_DLL void RabitGetProcessorName(char *out_name,
rbt_ulong *out_len,
rbt_ulong max_len);
/*!
* \brief broadcast a memory region to all others from root
*
* Example: int a = 1; RabitBroadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to the send or receive buffer
* \param size the size of the data
* \param root the root of process
*/
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
*
* Example Usage: the following code gives sum of the result
* (written against the C++ template interface; through this C entry point
* the data type and operation are selected by enum_dtype and enum_op instead)
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size());
* ...
* \param sendrecvbuf buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
* \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed into the lazy preprocessing function
*/
RABIT_DLL void RabitAllreduce(void *sendrecvbuf,
size_t count,
int enum_dtype,
int enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg);
/*!
* \brief load the latest check point
* \param out_global_model holds the output of the serialized global_model
* \param out_global_len the output length of the serialized global model
* \param out_local_model holds the output of the serialized local_model, can be NULL
* \param out_local_len the output length of the serialized local model, can be NULL
*
* \return the version number of the check point loaded
* if the returned version == 0, this means no model has been CheckPointed
* and nothing will be touched
*/
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
rbt_ulong *out_global_len,
char **out_local_model,
rbt_ulong *out_local_len);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param global_model holds the content of the serialized global_model
* \param global_len the content length of the serialized global model
* \param local_model holds the content of the serialized local_model, can be NULL
* \param local_len the content length of the serialized local model, can be NULL
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in the CheckPoint function. global_model does not need explicit replication.
* So only CheckPoint with global_model if possible
*/
RABIT_DLL void RabitCheckPoint(const char *global_model,
rbt_ulong global_len,
const char *local_model,
rbt_ulong local_len);
/*!
* \return version number of the currently stored model,
* which means how many calls to CheckPoint we made so far
*/
RABIT_DLL int RabitVersionNumber(void);
#endif // RABIT_C_API_H_

View File

@@ -4,10 +4,10 @@
* \brief This file defines the core interface of rabit library
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_ENGINE_H_
#define RABIT_ENGINE_H_
#ifndef RABIT_INTERNAL_ENGINE_H_
#define RABIT_INTERNAL_ENGINE_H_
#include <string>
#include "../rabit_serializable.h"
#include "../serializable.h"
namespace MPI {
/*! \brief MPI data type just to be compatible with MPI reduce function*/
@@ -241,7 +241,8 @@ class ReduceHandle {
* \param prepare_arg argument used to pass into the lazy preprocessing function
*/
void Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
size_t type_nbytes,
size_t count,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*! \return the number of bytes occupied by the type */
@@ -259,4 +260,4 @@ class ReduceHandle {
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_H_
#endif // RABIT_INTERNAL_ENGINE_H_

View File

@@ -4,15 +4,15 @@
* \brief utilities with different serializable implementations
* \author Tianqi Chen
*/
#ifndef RABIT_IO_H_
#define RABIT_IO_H_
#ifndef RABIT_INTERNAL_IO_H_
#define RABIT_INTERNAL_IO_H_
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <algorithm>
#include "./utils.h"
#include "../rabit_serializable.h"
#include "../serializable.h"
namespace rabit {
namespace utils {
@@ -103,4 +103,4 @@ struct MemoryBufferStream : public SeekStream {
}; // class MemoryBufferStream
} // namespace utils
} // namespace rabit
#endif // RABIT_IO_H_
#endif // RABIT_INTERNAL_IO_H_

View File

@@ -5,8 +5,8 @@
*
* \author Tianqi Chen
*/
#ifndef RABIT_RABIT_INL_H_
#define RABIT_RABIT_INL_H_
#ifndef RABIT_INTERNAL_RABIT_INL_H_
#define RABIT_INTERNAL_RABIT_INL_H_
// use engine for implementation
#include <vector>
#include <string>
@@ -325,4 +325,4 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
}
#endif
} // namespace rabit
#endif // RABIT_RABIT_INL_H_
#endif // RABIT_INTERNAL_RABIT_INL_H_

View File

@@ -4,8 +4,8 @@
* \brief This file defines the utils for timing
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_TIMER_H_
#define RABIT_TIMER_H_
#ifndef RABIT_INTERNAL_TIMER_H_
#define RABIT_INTERNAL_TIMER_H_
#include <time.h>
#ifdef __MACH__
#include <mach/clock.h>
@@ -38,4 +38,4 @@ inline double GetTime(void) {
}
} // namespace utils
} // namespace rabit
#endif // RABIT_TIMER_H_
#endif // RABIT_INTERNAL_TIMER_H_

View File

@@ -4,8 +4,8 @@
* \brief simple utils to support the code
* \author Tianqi Chen
*/
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
#ifndef RABIT_INTERNAL_UTILS_H_
#define RABIT_INTERNAL_UTILS_H_
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <string>
@@ -188,4 +188,4 @@ inline const char* BeginPtr(const std::string &str) {
return &str[0];
}
} // namespace rabit
#endif // RABIT_UTILS_H_
#endif // RABIT_INTERNAL_UTILS_H_

View File

@@ -22,15 +22,24 @@
#if DMLC_USE_CXX11
#include <functional>
#endif // C++11
// contains definition of Serializable
#include "./rabit_serializable.h"
// engine definition of rabit, defines internal implementation
// to use rabit interface, there is no need to read engine.h
// rabit.h and serializable.h are enough to use the interface
#include "./rabit/engine.h"
#include "./internal/engine.h"
/*! \brief rabit namespace */
namespace rabit {
/*!
* \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h
*/
typedef dmlc::Stream Stream;
/*!
* \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h
*/
typedef dmlc::Serializable Serializable;
/*!
* \brief reduction operators namespace
*/
@@ -65,16 +74,16 @@ inline void Init(int argc, char *argv[]);
/*!
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
*/
inline void Finalize(void);
inline void Finalize();
/*! \brief gets rank of the current process */
inline int GetRank(void);
inline int GetRank();
/*! \brief gets total number of processes */
inline int GetWorldSize(void);
inline int GetWorldSize();
/*! \brief whether rabit env is in distributed mode */
inline bool IsDistributed(void);
inline bool IsDistributed();
/*! \brief gets processor's name */
inline std::string GetProcessorName(void);
inline std::string GetProcessorName();
/*!
* \brief prints the msg to the tracker,
* this function can be used to communicate progress information to
@@ -241,7 +250,7 @@ inline void LazyCheckPoint(const Serializable *global_model);
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber(void);
inline int VersionNumber();
// ----- extensions that allow customized reducer ------
// helper class to do customized reduce, user do not need to know the type
namespace engine {
@@ -258,7 +267,7 @@ class ReduceHandle;
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
class Reducer {
public:
Reducer(void);
Reducer();
/*!
* \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
@@ -299,7 +308,7 @@ class Reducer {
template<typename DType>
class SerializeReducer {
public:
SerializeReducer(void);
SerializeReducer();
/*!
* \brief customized in-place all reduce operation
* \param sendrecvobj pointer to the array of objects to be reduced
@@ -338,5 +347,6 @@ class SerializeReducer {
};
} // namespace rabit
// implementation of template functions
#include "./rabit/rabit-inl.h"
#include "./internal/rabit-inl.h"
#endif // RABIT_RABIT_H_ // NOLINT(*)

View File

@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2014 by Contributors
* \file rabit_serializable.h
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
@@ -8,8 +8,8 @@
#define RABIT_SERIALIZABLE_H_
#include <vector>
#include <string>
#include "./rabit/utils.h"
#include "./dmlc/io.h"
#include "./internal/utils.h"
#include "../dmlc/io.h"
namespace rabit {
/*!