From cef360d78269dea24c491df4e742c5d2741793fa Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 26 Jul 2015 12:15:00 -0700 Subject: [PATCH] ok --- wrapper/rabit.py | 39 ++++++++------------------------------- wrapper/requirements.txt | 2 -- 2 files changed, 8 insertions(+), 33 deletions(-) delete mode 100644 wrapper/requirements.txt diff --git a/wrapper/rabit.py b/wrapper/rabit.py index 4822d70f1..91ce3e6ae 100644 --- a/wrapper/rabit.py +++ b/wrapper/rabit.py @@ -14,7 +14,6 @@ import numpy as np # version information about the doc __version__ = '1.0' - if os.name == 'nt': WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll' else: @@ -54,11 +53,6 @@ MIN = 1 SUM = 2 BITOR = 3 -def _check_err(): - """Reserved function used to check error. - """ - return - def init(args=None, lib='standard'): """Intialize the rabit module, call this once before using anything. @@ -77,7 +71,6 @@ def init(args=None, lib='standard'): arr = (ctypes.c_char_p * len(args))() arr[:] = args _LIB.RabitInit(len(args), arr) - _check_err() def finalize(): """Finalize the rabit engine. @@ -85,7 +78,6 @@ def finalize(): Call this function after you finished all jobs. """ _LIB.RabitFinalize() - _check_err() _unloadlib() def get_rank(): @@ -97,7 +89,6 @@ def get_rank(): Rank of current process. """ ret = _LIB.RabitGetRank() - _check_err() return ret def get_world_size(): @@ -109,7 +100,6 @@ def get_world_size(): Total number of process. """ ret = _LIB.RabitGetWorldSize() - _check_err() return ret def tracker_print(msg): @@ -126,7 +116,6 @@ def tracker_print(msg): if not isinstance(msg, str): msg = str(msg) _LIB.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8')) - _check_err() def get_processor_name(): """Get the processor name. @@ -139,9 +128,7 @@ def get_processor_name(): mxlen = 256 length = ctypes.c_ulong() buf = ctypes.create_string_buffer(mxlen) - _LIB.RabitGetProcessorName(buf, ctypes.byref(length), - mxlen) - _check_err() + _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen) return buf.value def broadcast(data, root): @@ -167,21 +154,17 @@ def broadcast(data, root): length.value = len(s) # run first broadcast _LIB.RabitBroadcast(ctypes.byref(length), - ctypes.sizeof(ctypes.c_ulong), - root) - _check_err() + ctypes.sizeof(ctypes.c_ulong), root) if root != rank: dptr = (ctypes.c_char * length.value)() # run second _LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), - length.value, root) - _check_err() + length.value, root) data = pickle.loads(dptr.raw) del dptr else: _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), - length.value, root) - _check_err() + length.value, root) del s return data @@ -230,17 +213,16 @@ def allreduce(data, op, prepare_fun=None): raise Exception('data type %s not supported' % str(buf.dtype)) if prepare_fun is None: _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), - buf.size, DTYPE_ENUM__[buf.dtype], - op, None, None) + buf.size, DTYPE_ENUM__[buf.dtype], + op, None, None) else: func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) def pfunc(args): """prepare function.""" prepare_fun(data) _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), - buf.size, DTYPE_ENUM__[buf.dtype], - op, func_ptr(pfunc), None) - _check_err() + buf.size, DTYPE_ENUM__[buf.dtype], + op, func_ptr(pfunc), None) return buf @@ -283,7 +265,6 @@ def load_checkpoint(with_local=False): ctypes.byref(global_len), ctypes.byref(lptr), ctypes.byref(local_len)) - _check_err() if version == 0: return (version, None, None) return (version, @@ -294,7 +275,6 @@ def load_checkpoint(with_local=False): ctypes.byref(gptr), ctypes.byref(global_len), None, None) - _check_err() if version == 0: return (version, None) return (version, @@ -326,12 +306,10 @@ def checkpoint(global_model, local_model=None): sglobal = pickle.dumps(global_model) if local_model is None: _LIB.RabitCheckPoint(sglobal, len(sglobal), None, 0) - _check_err() del sglobal else: slocal = pickle.dumps(local_model) _LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal)) - _check_err() del slocal del sglobal @@ -346,5 +324,4 @@ def version_number(): Version number of currently stored model """ ret = _LIB.RabitVersionNumber() - _check_err() return ret diff --git a/wrapper/requirements.txt b/wrapper/requirements.txt deleted file mode 100644 index 42d8ea92c..000000000 --- a/wrapper/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy==1.8.1 -