This commit is contained in:
tqchen 2015-07-26 12:15:00 -07:00
parent c125d2a8bb
commit cef360d782
2 changed files with 8 additions and 33 deletions

View File

@ -14,7 +14,6 @@ import numpy as np
# version information about the doc # version information about the doc
__version__ = '1.0' __version__ = '1.0'
if os.name == 'nt': if os.name == 'nt':
WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll' WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll'
else: else:
@ -54,11 +53,6 @@ MIN = 1
SUM = 2 SUM = 2
BITOR = 3 BITOR = 3
def _check_err():
"""Reserved function used to check error.
"""
return
def init(args=None, lib='standard'): def init(args=None, lib='standard'):
"""Intialize the rabit module, call this once before using anything. """Intialize the rabit module, call this once before using anything.
@ -77,7 +71,6 @@ def init(args=None, lib='standard'):
arr = (ctypes.c_char_p * len(args))() arr = (ctypes.c_char_p * len(args))()
arr[:] = args arr[:] = args
_LIB.RabitInit(len(args), arr) _LIB.RabitInit(len(args), arr)
_check_err()
def finalize(): def finalize():
"""Finalize the rabit engine. """Finalize the rabit engine.
@ -85,7 +78,6 @@ def finalize():
Call this function after you finished all jobs. Call this function after you finished all jobs.
""" """
_LIB.RabitFinalize() _LIB.RabitFinalize()
_check_err()
_unloadlib() _unloadlib()
def get_rank(): def get_rank():
@ -97,7 +89,6 @@ def get_rank():
Rank of current process. Rank of current process.
""" """
ret = _LIB.RabitGetRank() ret = _LIB.RabitGetRank()
_check_err()
return ret return ret
def get_world_size(): def get_world_size():
@ -109,7 +100,6 @@ def get_world_size():
Total number of process. Total number of process.
""" """
ret = _LIB.RabitGetWorldSize() ret = _LIB.RabitGetWorldSize()
_check_err()
return ret return ret
def tracker_print(msg): def tracker_print(msg):
@ -126,7 +116,6 @@ def tracker_print(msg):
if not isinstance(msg, str): if not isinstance(msg, str):
msg = str(msg) msg = str(msg)
_LIB.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8')) _LIB.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8'))
_check_err()
def get_processor_name(): def get_processor_name():
"""Get the processor name. """Get the processor name.
@ -139,9 +128,7 @@ def get_processor_name():
mxlen = 256 mxlen = 256
length = ctypes.c_ulong() length = ctypes.c_ulong()
buf = ctypes.create_string_buffer(mxlen) buf = ctypes.create_string_buffer(mxlen)
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
mxlen)
_check_err()
return buf.value return buf.value
def broadcast(data, root): def broadcast(data, root):
@ -167,21 +154,17 @@ def broadcast(data, root):
length.value = len(s) length.value = len(s)
# run first broadcast # run first broadcast
_LIB.RabitBroadcast(ctypes.byref(length), _LIB.RabitBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong), ctypes.sizeof(ctypes.c_ulong), root)
root)
_check_err()
if root != rank: if root != rank:
dptr = (ctypes.c_char * length.value)() dptr = (ctypes.c_char * length.value)()
# run second # run second
_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), _LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
length.value, root) length.value, root)
_check_err()
data = pickle.loads(dptr.raw) data = pickle.loads(dptr.raw)
del dptr del dptr
else: else:
_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
length.value, root) length.value, root)
_check_err()
del s del s
return data return data
@ -230,17 +213,16 @@ def allreduce(data, op, prepare_fun=None):
raise Exception('data type %s not supported' % str(buf.dtype)) raise Exception('data type %s not supported' % str(buf.dtype))
if prepare_fun is None: if prepare_fun is None:
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype], buf.size, DTYPE_ENUM__[buf.dtype],
op, None, None) op, None, None)
else: else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pfunc(args): def pfunc(args):
"""prepare function.""" """prepare function."""
prepare_fun(data) prepare_fun(data)
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype], buf.size, DTYPE_ENUM__[buf.dtype],
op, func_ptr(pfunc), None) op, func_ptr(pfunc), None)
_check_err()
return buf return buf
@ -283,7 +265,6 @@ def load_checkpoint(with_local=False):
ctypes.byref(global_len), ctypes.byref(global_len),
ctypes.byref(lptr), ctypes.byref(lptr),
ctypes.byref(local_len)) ctypes.byref(local_len))
_check_err()
if version == 0: if version == 0:
return (version, None, None) return (version, None, None)
return (version, return (version,
@ -294,7 +275,6 @@ def load_checkpoint(with_local=False):
ctypes.byref(gptr), ctypes.byref(gptr),
ctypes.byref(global_len), ctypes.byref(global_len),
None, None) None, None)
_check_err()
if version == 0: if version == 0:
return (version, None) return (version, None)
return (version, return (version,
@ -326,12 +306,10 @@ def checkpoint(global_model, local_model=None):
sglobal = pickle.dumps(global_model) sglobal = pickle.dumps(global_model)
if local_model is None: if local_model is None:
_LIB.RabitCheckPoint(sglobal, len(sglobal), None, 0) _LIB.RabitCheckPoint(sglobal, len(sglobal), None, 0)
_check_err()
del sglobal del sglobal
else: else:
slocal = pickle.dumps(local_model) slocal = pickle.dumps(local_model)
_LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal)) _LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal))
_check_err()
del slocal del slocal
del sglobal del sglobal
@ -346,5 +324,4 @@ def version_number():
Version number of currently stored model Version number of currently stored model
""" """
ret = _LIB.RabitVersionNumber() ret = _LIB.RabitVersionNumber()
_check_err()
return ret return ret

View File

@ -1,2 +0,0 @@
numpy==1.8.1