add wrapper
This commit is contained in:
parent
61626aaf85
commit
9a4a81f100
3
.gitignore
vendored
3
.gitignore
vendored
@ -32,4 +32,5 @@
|
|||||||
*.exe
|
*.exe
|
||||||
*.txt
|
*.txt
|
||||||
*tmp*
|
*tmp*
|
||||||
doc
|
doc
|
||||||
|
*.rabit
|
||||||
|
|||||||
17
Makefile
17
Makefile
@ -1,20 +1,23 @@
|
|||||||
export CC = gcc
|
export CC = gcc
|
||||||
export CXX = g++
|
export CXX = g++
|
||||||
export MPICXX = mpicxx
|
export MPICXX = mpicxx
|
||||||
export LDFLAGS=
|
export LDFLAGS= -Llib
|
||||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -Iinclude
|
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -Iinclude
|
||||||
|
|
||||||
# build path
|
# build path
|
||||||
BPATH=.
|
BPATH=.
|
||||||
# objectives that makes up rabit library
|
# objectives that makes up rabit library
|
||||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o
|
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
||||||
|
$(BPATH)/rabit_wrapper.o
|
||||||
|
SLIB= wrapper/librabit_wrapper.so
|
||||||
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
||||||
HEADERS=src/*.h include/*.h include/rabit/*.h
|
HEADERS=src/*.h include/*.h include/rabit/*.h
|
||||||
.PHONY: clean all install mpi
|
.PHONY: clean all install mpi python
|
||||||
|
|
||||||
all: lib/librabit.a lib/librabit_mock.a
|
all: lib/librabit.a lib/librabit_mock.a $(SLIB)
|
||||||
mpi: lib/librabit_mpi.a
|
mpi: lib/librabit_mpi.a
|
||||||
|
python: wrapper/librabit_wrapper.so
|
||||||
|
|
||||||
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
||||||
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
||||||
@ -27,6 +30,9 @@ lib/librabit.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/e
|
|||||||
lib/librabit_mock.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine_mock.o
|
lib/librabit_mock.a: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine_mock.o
|
||||||
lib/librabit_empty.a: $(BPATH)/engine_empty.o
|
lib/librabit_empty.a: $(BPATH)/engine_empty.o
|
||||||
lib/librabit_mpi.a: $(MPIOBJ)
|
lib/librabit_mpi.a: $(MPIOBJ)
|
||||||
|
# wrapper code
|
||||||
|
$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc
|
||||||
|
wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
@ -37,5 +43,8 @@ $(MPIOBJ) :
|
|||||||
$(ALIB):
|
$(ALIB):
|
||||||
ar cr $@ $+
|
ar cr $@ $+
|
||||||
|
|
||||||
|
$(SLIB) :
|
||||||
|
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^)
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~
|
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~
|
||||||
|
|||||||
164
wrapper/rabit_wrapper.cc
Normal file
164
wrapper/rabit_wrapper.cc
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
// implementations in ctypes
|
||||||
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
|
|
||||||
|
#include <rabit.h>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include "./rabit_wrapper.h"
|
||||||
|
namespace rabit {
|
||||||
|
namespace wrapper {
|
||||||
|
// helper use to avoid BitOR operator
|
||||||
|
template<typename OP, typename DType>
|
||||||
|
struct FHelper {
|
||||||
|
inline static void
|
||||||
|
Allreduce(DType *senrecvbuf_,
|
||||||
|
size_t count,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
rabit::Allreduce<OP>(senrecvbuf_, count,
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template<typename DType>
|
||||||
|
struct FHelper<op::BitOR, DType> {
|
||||||
|
inline static void
|
||||||
|
Allreduce(DType *senrecvbuf_,
|
||||||
|
size_t count,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
utils::Error("DataType does not support bitwise or operation");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template<typename OP>
|
||||||
|
inline void Allreduce_(void *sendrecvbuf_,
|
||||||
|
size_t count,
|
||||||
|
engine::mpi::DataType enum_dtype,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
using namespace engine::mpi;
|
||||||
|
switch (enum_dtype) {
|
||||||
|
case kChar:
|
||||||
|
rabit::Allreduce<OP>
|
||||||
|
(static_cast<char*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kUChar:
|
||||||
|
rabit::Allreduce<OP>
|
||||||
|
(static_cast<unsigned char*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kInt:
|
||||||
|
rabit::Allreduce<OP>
|
||||||
|
(static_cast<int*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kUInt:
|
||||||
|
rabit::Allreduce<OP>
|
||||||
|
(static_cast<unsigned*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kLong:
|
||||||
|
rabit::Allreduce<OP>
|
||||||
|
(static_cast<long*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kULong:
|
||||||
|
rabit::Allreduce<OP>
|
||||||
|
(static_cast<unsigned long*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kFloat:
|
||||||
|
FHelper<OP, float>::Allreduce
|
||||||
|
(static_cast<float*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kDouble:
|
||||||
|
FHelper<OP, double>::Allreduce
|
||||||
|
(static_cast<double*>(sendrecvbuf_),
|
||||||
|
count, prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
default: utils::Error("unknown data_type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline void Allreduce(void *sendrecvbuf,
|
||||||
|
size_t count,
|
||||||
|
engine::mpi::DataType enum_dtype,
|
||||||
|
engine::mpi::OpType enum_op,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
using namespace engine::mpi;
|
||||||
|
switch (enum_op) {
|
||||||
|
case kMax:
|
||||||
|
Allreduce_<op::Max>
|
||||||
|
(sendrecvbuf,
|
||||||
|
count, enum_dtype,
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kMin:
|
||||||
|
Allreduce_<op::Min>
|
||||||
|
(sendrecvbuf,
|
||||||
|
count, enum_dtype,
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kSum:
|
||||||
|
Allreduce_<op::Sum>
|
||||||
|
(sendrecvbuf,
|
||||||
|
count, enum_dtype,
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
case kBitwiseOR:
|
||||||
|
Allreduce_<op::BitOR>
|
||||||
|
(sendrecvbuf,
|
||||||
|
count, enum_dtype,
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
return;
|
||||||
|
default: utils::Error("unknown enum_op");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace wrapper
|
||||||
|
} // namespace rabit
|
||||||
|
extern "C" {
|
||||||
|
void RabitInit(int argc, char *argv[]) {
|
||||||
|
rabit::Init(argc, argv);
|
||||||
|
}
|
||||||
|
void RabitFinalize(void) {
|
||||||
|
rabit::Finalize();
|
||||||
|
}
|
||||||
|
int RabitGetRank(void) {
|
||||||
|
return rabit::GetRank();
|
||||||
|
}
|
||||||
|
int RabitGetWorldSize(void) {
|
||||||
|
return rabit::GetWorldSize();
|
||||||
|
}
|
||||||
|
void RabitTrackerPrint(const char *msg) {
|
||||||
|
std::string m(msg);
|
||||||
|
rabit::TrackerPrint(m);
|
||||||
|
}
|
||||||
|
void RabitGetProcessorName(char *out_name,
|
||||||
|
rbt_ulong *out_len,
|
||||||
|
rbt_ulong max_len) {
|
||||||
|
std::string s = rabit::GetProcessorName();
|
||||||
|
if (s.length() > max_len) {
|
||||||
|
s.resize(max_len - 1);
|
||||||
|
}
|
||||||
|
strcpy(out_name, s.c_str());
|
||||||
|
*out_len = static_cast<rbt_ulong>(s.length());
|
||||||
|
}
|
||||||
|
void RabitBroadcast(void *sendrecv_data,
|
||||||
|
rbt_ulong size, int root) {
|
||||||
|
rabit::Broadcast(sendrecv_data, size, root);
|
||||||
|
}
|
||||||
|
void RabitAllreduce(void *sendrecvbuf,
|
||||||
|
size_t count,
|
||||||
|
int enum_dtype,
|
||||||
|
int enum_op,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
rabit::wrapper::Allreduce
|
||||||
|
(sendrecvbuf, count,
|
||||||
|
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
|
||||||
|
static_cast<rabit::engine::mpi::OpType>(enum_op),
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
}
|
||||||
|
}
|
||||||
87
wrapper/rabit_wrapper.h
Normal file
87
wrapper/rabit_wrapper.h
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
#ifndef RABIT_WRAPPER_H_
|
||||||
|
#define RABIT_WRAPPER_H_
|
||||||
|
/*!
|
||||||
|
* \file rabit_wrapper.h
|
||||||
|
* \author Tianqi Chen
|
||||||
|
* \brief a C style wrapper of rabit
|
||||||
|
* can be used to create wrapper of other languages
|
||||||
|
*/
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#define RABIT_DLL __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define RABIT_DLL
|
||||||
|
#endif
|
||||||
|
// manually define unsign long
|
||||||
|
typedef unsigned long rbt_ulong;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
/*!
|
||||||
|
* \brief intialize the rabit module, call this once before using anything
|
||||||
|
* \param argc number of arguments in argv
|
||||||
|
* \param argv the array of input arguments
|
||||||
|
*/
|
||||||
|
RABIT_DLL void RabitInit(int argc, char *argv[]);
|
||||||
|
/*!
|
||||||
|
* \brief finalize the rabit engine, call this function after you finished all jobs
|
||||||
|
*/
|
||||||
|
RABIT_DLL void RabitFinalize(void);
|
||||||
|
/*! \brief get rank of current process */
|
||||||
|
RABIT_DLL int RabitGetRank(void);
|
||||||
|
/*! \brief get total number of process */
|
||||||
|
RABIT_DLL int RabitGetWorldSize(void);
|
||||||
|
/*!
|
||||||
|
* \brief print the msg to the tracker,
|
||||||
|
* this function can be used to communicate the information of the progress to
|
||||||
|
* the user who monitors the tracker
|
||||||
|
* \param msg the message to be printed
|
||||||
|
*/
|
||||||
|
RABIT_DLL void RabitTrackerPrint(const char *msg);
|
||||||
|
/*!
|
||||||
|
* \brief get name of processor
|
||||||
|
* \param out_name hold output string
|
||||||
|
* \param out_len hold length of output string
|
||||||
|
* \param max_len maximum buffer length of input
|
||||||
|
*/
|
||||||
|
RABIT_DLL void RabitGetProcessorName(char *out_name,
|
||||||
|
rbt_ulong *out_len,
|
||||||
|
rbt_ulong max_len);
|
||||||
|
/*!
|
||||||
|
* \brief broadcast an memory region to all others from root
|
||||||
|
*
|
||||||
|
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||||
|
* \param sendrecv_data the pointer to send or recive buffer,
|
||||||
|
* \param size the size of the data
|
||||||
|
* \param root the root of process
|
||||||
|
*/
|
||||||
|
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
|
||||||
|
rbt_ulong size, int root);
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* this function is NOT thread-safe
|
||||||
|
*
|
||||||
|
* Example Usage: the following code gives sum of the result
|
||||||
|
* vector<int> data(10);
|
||||||
|
* ...
|
||||||
|
* Allreduce<op::Sum>(&data[0], data.size());
|
||||||
|
* ...
|
||||||
|
* \param sendrecvbuf buffer for both sending and recving data
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
|
||||||
|
* \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit
|
||||||
|
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||||
|
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||||
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
|
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||||
|
*/
|
||||||
|
RABIT_DLL void RabitAllreduce(void *sendrecvbuf,
|
||||||
|
size_t count,
|
||||||
|
int enum_dtype,
|
||||||
|
int enum_op,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // C
|
||||||
|
#endif
|
||||||
|
#endif // XGBOOST_WRAPPER_H_
|
||||||
Loading…
x
Reference in New Issue
Block a user