From 270a49ee757cc0121de1f0f21b0d4b51026712d4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 23 Jul 2015 22:22:52 -0700 Subject: [PATCH] add requirments --- Makefile | 4 +- include/rabit.h | 6 +- wrapper/rabit.py | 315 ++++++++++++++++++++++----------------- wrapper/requirements.txt | 2 + 4 files changed, 183 insertions(+), 144 deletions(-) create mode 100644 wrapper/requirements.txt diff --git a/Makefile b/Makefile index 92e899b51..21c6360a2 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a HEADERS=src/*.h include/*.h include/rabit/*.h DMLC=dmlc-core -.PHONY: clean all install mpi python lint doc +.PHONY: clean all install mpi python lint doc doxygen all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so @@ -68,7 +68,7 @@ $(SLIB) : lint: $(DMLC)/scripts/lint.py rabit $(LINT_LANG) src include wrapper -doc: +doc doxygen: cd include; doxygen ../doc/Doxyfile; cd - clean: diff --git a/include/rabit.h b/include/rabit.h index 8513d8251..dfc7cf369 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -138,7 +138,7 @@ inline void Broadcast(std::string *sendrecv_data, int root); */ template inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *arg) = NULL, + void (*prepare_fun)(void *) = NULL, void *prepare_arg = NULL); // C++11 support for lambda prepare function #if DMLC_USE_CXX11 @@ -262,7 +262,7 @@ class Reducer { * \param prepare_arg argument used to pass into the lazy preprocessing function */ inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *arg) = NULL, + void (*prepare_fun)(void *) = NULL, void *prepare_arg = NULL); #if DMLC_USE_CXX11 /*! @@ -306,7 +306,7 @@ class SerializeReducer { */ inline void Allreduce(DType *sendrecvobj, size_t max_nbyte, size_t count, - void (*prepare_fun)(void *arg) = NULL, + void (*prepare_fun)(void *) = NULL, void *prepare_arg = NULL); // C++11 support for lambda prepare function #if DMLC_USE_CXX11 diff --git a/wrapper/rabit.py b/wrapper/rabit.py index 34091003d..4822d70f1 100644 --- a/wrapper/rabit.py +++ b/wrapper/rabit.py @@ -1,6 +1,6 @@ """ -Python interface for rabit - Reliable Allreduce and Broadcast Library +Reliable Allreduce and Broadcast Library. + Author: Tianqi Chen """ # pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value, @@ -11,37 +11,42 @@ import sys import warnings import numpy as np +# version information about the doc +__version__ = '1.0' + + if os.name == 'nt': WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll' else: WRAPPER_PATH = os.path.dirname(__file__) + '/librabit_wrapper%s.so' -rbtlib = None + +_LIB = None # load in xgboost library -def loadlib__(lib='standard'): - """Load rabit library""" - global rbtlib - if rbtlib != None: +def _loadlib(lib='standard'): + """Load rabit library.""" + global _LIB + if _LIB != None: warnings.warn('rabit.int call was ignored because it has'\ ' already been initialized', level=2) return if lib == 'standard': - rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '') + _LIB = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '') elif lib == 'mock': - rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '_mock') + _LIB = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '_mock') elif lib == 'mpi': - rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '_mpi') + _LIB = ctypes.cdll.LoadLibrary(WRAPPER_PATH % '_mpi') else: raise Exception('unknown rabit lib %s, can be standard, mock, mpi' % lib) - rbtlib.RabitGetRank.restype = ctypes.c_int - rbtlib.RabitGetWorldSize.restype = ctypes.c_int - rbtlib.RabitVersionNumber.restype = ctypes.c_int + _LIB.RabitGetRank.restype = ctypes.c_int + _LIB.RabitGetWorldSize.restype = ctypes.c_int + _LIB.RabitVersionNumber.restype = ctypes.c_int -def unloadlib__(): - """Unload rabit library""" - global rbtlib - del rbtlib - rbtlib = None +def _unloadlib(): + """Unload rabit library.""" + global _LIB + del _LIB + _LIB = None # reduction operators MAX = 0 @@ -49,101 +54,110 @@ MIN = 1 SUM = 2 BITOR = 3 -def check_err__(): - """ - reserved function used to check error +def _check_err(): + """Reserved function used to check error. """ return -def init(args=sys.argv, lib='standard'): +def init(args=None, lib='standard'): + """Intialize the rabit module, call this once before using anything. + + Parameters + ---------- + args: list of str, optional + The list of arguments used to initialized the rabit + usually you need to pass in sys.argv. + Defaults to sys.argv when it is None. + lib: {'standard', 'mock', 'mpi'} + Type of library we want to load """ - intialize the rabit module, call this once before using anything - Arguments: - args: list(string) [default=sys.argv] - the list of arguments used to initialized the rabit - usually you need to pass in sys.argv - with_mock: boolean [default=False] - Whether initialize the mock test module - """ - loadlib__(lib) + if args is None: + args = sys.argv + _loadlib(lib) arr = (ctypes.c_char_p * len(args))() arr[:] = args - rbtlib.RabitInit(len(args), arr) - check_err__() + _LIB.RabitInit(len(args), arr) + _check_err() def finalize(): + """Finalize the rabit engine. + + Call this function after you finished all jobs. """ - finalize the rabit engine, call this function after you finished all jobs - """ - rbtlib.RabitFinalize() - check_err__() - unloadlib__() + _LIB.RabitFinalize() + _check_err() + _unloadlib() def get_rank(): + """Get rank of current process. + + Returns + ------- + rank : int + Rank of current process. """ - Returns rank of current process - """ - ret = rbtlib.RabitGetRank() - check_err__() + ret = _LIB.RabitGetRank() + _check_err() return ret def get_world_size(): + """Get total number workers. + + Returns + ------- + n : int + Total number of process. """ - Returns get total number of process - """ - ret = rbtlib.RabitGetWorldSize() - check_err__() + ret = _LIB.RabitGetWorldSize() + _check_err() return ret def tracker_print(msg): - """ - print message to the tracker - this function can be used to communicate the information of the progress - to the tracker + """Print message to the tracker. + + This function can be used to communicate the information of + the progress to the tracker + + Parameters + ---------- + msg : str + The message to be printed to tracker. """ if not isinstance(msg, str): msg = str(msg) - rbtlib.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8')) - check_err__() + _LIB.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8')) + _check_err() def get_processor_name(): - """ - Returns the name of processor(host) + """Get the processor name. + + Returns + ------- + name : str + the name of processor(host) """ mxlen = 256 length = ctypes.c_ulong() buf = ctypes.create_string_buffer(mxlen) - rbtlib.RabitGetProcessorName(buf, ctypes.byref(length), + _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen) - check_err__() + _check_err() return buf.value def broadcast(data, root): - """ - broadcast object from one node to all other nodes - this function will return the broadcasted object + """Broadcast object from one node to all other nodes. - Example: the following example broadcast hello from rank 0 to all other nodes - ```python - rabit.init() - n = 3 - rank = rabit.get_rank() - s = None - if rank == 0: - s = {'hello world':100, 2:3} - print '@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s)) - s = rabit.broadcast(s, 0) - print '@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s)) - rabit.finalize() - ``` + Parameters + ---------- + data : any type that can be pickled + Input data, if current rank does not equal root, this can be None + root : int + Rank of the node to broadcast data from. - Arguments: - data: anytype that can be pickled - input data, if current rank does not equal root, this can be None - root: int - rank of the node to broadcast data from - Returns: - the result of broadcast + Returns + ------- + object : int + the result of broadcast. """ rank = get_rank() length = ctypes.c_ulong() @@ -152,22 +166,22 @@ def broadcast(data, root): s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) length.value = len(s) # run first broadcast - rbtlib.RabitBroadcast(ctypes.byref(length), + _LIB.RabitBroadcast(ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root) - check_err__() + _check_err() if root != rank: dptr = (ctypes.c_char * length.value)() # run second - rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), + _LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), length.value, root) - check_err__() + _check_err() data = pickle.loads(dptr.raw) del dptr else: - rbtlib.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), + _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), length.value, root) - check_err__() + _check_err() del s return data @@ -184,20 +198,28 @@ DTYPE_ENUM__ = { } def allreduce(data, op, prepare_fun=None): - """ - perform allreduce, return the result, this function is not thread-safe - Arguments: - data: numpy ndarray - input data - op: int - reduction operators, can be MIN, MAX, SUM, BITOR - prepare_fun: lambda data - Lazy preprocessing function, if it is not None, prepare_fun(data) - will be called by the function before performing allreduce, to intialize the data - If the result of Allreduce can be recovered directly, - then prepare_fun will NOT be called - Returns: - the result of allreduce, have same shape as data + """Perform allreduce, return the result. + + Parameters + ---------- + data: numpy array + Input data. + op: int + Reduction operators, can be MIN, MAX, SUM, BITOR + prepare_fun: function + Lazy preprocessing function, if it is not None, prepare_fun(data) + will be called by the function before performing allreduce, to intialize the data + If the result of Allreduce can be recovered directly, + then prepare_fun will NOT be called + + Returns + ------- + result : array_like + The result of allreduce, have same shape as data + + Notes + ----- + This function is not thread-safe. """ if not isinstance(data, np.ndarray): raise Exception('allreduce only takes in numpy.ndarray') @@ -207,7 +229,7 @@ def allreduce(data, op, prepare_fun=None): if buf.dtype not in DTYPE_ENUM__: raise Exception('data type %s not supported' % str(buf.dtype)) if prepare_fun is None: - rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), + _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), buf.size, DTYPE_ENUM__[buf.dtype], op, None, None) else: @@ -215,14 +237,14 @@ def allreduce(data, op, prepare_fun=None): def pfunc(args): """prepare function.""" prepare_fun(data) - rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), + _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), buf.size, DTYPE_ENUM__[buf.dtype], op, func_ptr(pfunc), None) - check_err__() + _check_err() return buf -def load_model__(ptr, length): +def _load_model(ptr, length): """ Internal function used by the module, unpickle a model from a buffer specified by ptr, length @@ -236,12 +258,16 @@ def load_model__(ptr, length): return pickle.loads(data.raw) def load_checkpoint(with_local=False): - """ - load latest check point - Arguments: - with_local: boolean [default = False] - whether the checkpoint contains local model - Returns: + """Load latest check point. + + Parameters + ---------- + with_local: bool, optional + whether the checkpoint contains local model + + Returns + ------- + tuple : tuple if with_local: return (version, gobal_model, local_model) else return (version, gobal_model) if returned version == 0, this means no model has been CheckPointed @@ -252,62 +278,73 @@ def load_checkpoint(with_local=False): if with_local: lptr = ctypes.POINTER(ctypes.c_char)() local_len = ctypes.c_ulong() - version = rbtlib.RabitLoadCheckPoint( + version = _LIB.RabitLoadCheckPoint( ctypes.byref(gptr), ctypes.byref(global_len), ctypes.byref(lptr), ctypes.byref(local_len)) - check_err__() + _check_err() if version == 0: return (version, None, None) return (version, - load_model__(gptr, global_len.value), - load_model__(lptr, local_len.value)) + _load_model(gptr, global_len.value), + _load_model(lptr, local_len.value)) else: - version = rbtlib.RabitLoadCheckPoint( + version = _LIB.RabitLoadCheckPoint( ctypes.byref(gptr), ctypes.byref(global_len), None, None) - check_err__() + _check_err() if version == 0: return (version, None) return (version, - load_model__(gptr, global_len.value)) + _load_model(gptr, global_len.value)) def checkpoint(global_model, local_model=None): - """ - checkpoint the model, meaning we finished a stage of execution - every time we call check point, there is a version number which will increase by one + """Checkpoint the model. - Arguments: - global_model: anytype that can be pickled - globally shared model/state when calling this function, - the caller need to gauranttees that global_model is the same in all nodes - local_model: anytype that can be pickled - local model, that is specific to current node/rank. - This can be None when no local state is needed. - local_model requires explicit replication of the model for fault-tolerance, - which will bring replication cost in checkpoint function, - while global_model do not need explicit replication. - It is recommended to use global_model if possible + This means we finished a stage of execution. + Every time we call check point, there is a version number which will increase by one. + + Parameters + ---------- + global_model: anytype that can be pickled + globally shared model/state when calling this function, + the caller need to gauranttees that global_model is the same in all nodes + + local_model: anytype that can be pickled + Local model, that is specific to current node/rank. + This can be None when no local state is needed. + + Notes + ----- + local_model requires explicit replication of the model for fault-tolerance. + This will bring replication cost in checkpoint function. + while global_model do not need explicit replication. + It is recommended to use global_model if possible. """ sglobal = pickle.dumps(global_model) if local_model is None: - rbtlib.RabitCheckPoint(sglobal, len(sglobal), None, 0) - check_err__() + _LIB.RabitCheckPoint(sglobal, len(sglobal), None, 0) + _check_err() del sglobal else: slocal = pickle.dumps(local_model) - rbtlib.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal)) - check_err__() + _LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal)) + _check_err() del slocal del sglobal def version_number(): + """Returns version number of current stored model. + + This means how many calls to CheckPoint we made so far. + + Returns + ------- + version : int + Version number of currently stored model """ - Returns version number of current stored model, - which means how many calls to CheckPoint we made so far - """ - ret = rbtlib.RabitVersionNumber() - check_err__() + ret = _LIB.RabitVersionNumber() + _check_err() return ret diff --git a/wrapper/requirements.txt b/wrapper/requirements.txt new file mode 100644 index 000000000..42d8ea92c --- /dev/null +++ b/wrapper/requirements.txt @@ -0,0 +1,2 @@ +numpy==1.8.1 +