ok
This commit is contained in:
parent
c125d2a8bb
commit
cef360d782
@ -14,7 +14,6 @@ import numpy as np
|
|||||||
# version information about the doc
|
# version information about the doc
|
||||||
__version__ = '1.0'
|
__version__ = '1.0'
|
||||||
|
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll'
|
WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll'
|
||||||
else:
|
else:
|
||||||
@ -54,11 +53,6 @@ MIN = 1
|
|||||||
SUM = 2
|
SUM = 2
|
||||||
BITOR = 3
|
BITOR = 3
|
||||||
|
|
||||||
def _check_err():
|
|
||||||
"""Reserved function used to check error.
|
|
||||||
"""
|
|
||||||
return
|
|
||||||
|
|
||||||
def init(args=None, lib='standard'):
|
def init(args=None, lib='standard'):
|
||||||
"""Intialize the rabit module, call this once before using anything.
|
"""Intialize the rabit module, call this once before using anything.
|
||||||
|
|
||||||
@ -77,7 +71,6 @@ def init(args=None, lib='standard'):
|
|||||||
arr = (ctypes.c_char_p * len(args))()
|
arr = (ctypes.c_char_p * len(args))()
|
||||||
arr[:] = args
|
arr[:] = args
|
||||||
_LIB.RabitInit(len(args), arr)
|
_LIB.RabitInit(len(args), arr)
|
||||||
_check_err()
|
|
||||||
|
|
||||||
def finalize():
|
def finalize():
|
||||||
"""Finalize the rabit engine.
|
"""Finalize the rabit engine.
|
||||||
@ -85,7 +78,6 @@ def finalize():
|
|||||||
Call this function after you finished all jobs.
|
Call this function after you finished all jobs.
|
||||||
"""
|
"""
|
||||||
_LIB.RabitFinalize()
|
_LIB.RabitFinalize()
|
||||||
_check_err()
|
|
||||||
_unloadlib()
|
_unloadlib()
|
||||||
|
|
||||||
def get_rank():
|
def get_rank():
|
||||||
@ -97,7 +89,6 @@ def get_rank():
|
|||||||
Rank of current process.
|
Rank of current process.
|
||||||
"""
|
"""
|
||||||
ret = _LIB.RabitGetRank()
|
ret = _LIB.RabitGetRank()
|
||||||
_check_err()
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size():
|
||||||
@ -109,7 +100,6 @@ def get_world_size():
|
|||||||
Total number of process.
|
Total number of process.
|
||||||
"""
|
"""
|
||||||
ret = _LIB.RabitGetWorldSize()
|
ret = _LIB.RabitGetWorldSize()
|
||||||
_check_err()
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def tracker_print(msg):
|
def tracker_print(msg):
|
||||||
@ -126,7 +116,6 @@ def tracker_print(msg):
|
|||||||
if not isinstance(msg, str):
|
if not isinstance(msg, str):
|
||||||
msg = str(msg)
|
msg = str(msg)
|
||||||
_LIB.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8'))
|
_LIB.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8'))
|
||||||
_check_err()
|
|
||||||
|
|
||||||
def get_processor_name():
|
def get_processor_name():
|
||||||
"""Get the processor name.
|
"""Get the processor name.
|
||||||
@ -139,9 +128,7 @@ def get_processor_name():
|
|||||||
mxlen = 256
|
mxlen = 256
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
buf = ctypes.create_string_buffer(mxlen)
|
buf = ctypes.create_string_buffer(mxlen)
|
||||||
_LIB.RabitGetProcessorName(buf, ctypes.byref(length),
|
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
|
||||||
mxlen)
|
|
||||||
_check_err()
|
|
||||||
return buf.value
|
return buf.value
|
||||||
|
|
||||||
def broadcast(data, root):
|
def broadcast(data, root):
|
||||||
@ -167,21 +154,17 @@ def broadcast(data, root):
|
|||||||
length.value = len(s)
|
length.value = len(s)
|
||||||
# run first broadcast
|
# run first broadcast
|
||||||
_LIB.RabitBroadcast(ctypes.byref(length),
|
_LIB.RabitBroadcast(ctypes.byref(length),
|
||||||
ctypes.sizeof(ctypes.c_ulong),
|
ctypes.sizeof(ctypes.c_ulong), root)
|
||||||
root)
|
|
||||||
_check_err()
|
|
||||||
if root != rank:
|
if root != rank:
|
||||||
dptr = (ctypes.c_char * length.value)()
|
dptr = (ctypes.c_char * length.value)()
|
||||||
# run second
|
# run second
|
||||||
_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
||||||
length.value, root)
|
length.value, root)
|
||||||
_check_err()
|
|
||||||
data = pickle.loads(dptr.raw)
|
data = pickle.loads(dptr.raw)
|
||||||
del dptr
|
del dptr
|
||||||
else:
|
else:
|
||||||
_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
||||||
length.value, root)
|
length.value, root)
|
||||||
_check_err()
|
|
||||||
del s
|
del s
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@ -240,7 +223,6 @@ def allreduce(data, op, prepare_fun=None):
|
|||||||
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
op, func_ptr(pfunc), None)
|
op, func_ptr(pfunc), None)
|
||||||
_check_err()
|
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
|
|
||||||
@ -283,7 +265,6 @@ def load_checkpoint(with_local=False):
|
|||||||
ctypes.byref(global_len),
|
ctypes.byref(global_len),
|
||||||
ctypes.byref(lptr),
|
ctypes.byref(lptr),
|
||||||
ctypes.byref(local_len))
|
ctypes.byref(local_len))
|
||||||
_check_err()
|
|
||||||
if version == 0:
|
if version == 0:
|
||||||
return (version, None, None)
|
return (version, None, None)
|
||||||
return (version,
|
return (version,
|
||||||
@ -294,7 +275,6 @@ def load_checkpoint(with_local=False):
|
|||||||
ctypes.byref(gptr),
|
ctypes.byref(gptr),
|
||||||
ctypes.byref(global_len),
|
ctypes.byref(global_len),
|
||||||
None, None)
|
None, None)
|
||||||
_check_err()
|
|
||||||
if version == 0:
|
if version == 0:
|
||||||
return (version, None)
|
return (version, None)
|
||||||
return (version,
|
return (version,
|
||||||
@ -326,12 +306,10 @@ def checkpoint(global_model, local_model=None):
|
|||||||
sglobal = pickle.dumps(global_model)
|
sglobal = pickle.dumps(global_model)
|
||||||
if local_model is None:
|
if local_model is None:
|
||||||
_LIB.RabitCheckPoint(sglobal, len(sglobal), None, 0)
|
_LIB.RabitCheckPoint(sglobal, len(sglobal), None, 0)
|
||||||
_check_err()
|
|
||||||
del sglobal
|
del sglobal
|
||||||
else:
|
else:
|
||||||
slocal = pickle.dumps(local_model)
|
slocal = pickle.dumps(local_model)
|
||||||
_LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal))
|
_LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, len(slocal))
|
||||||
_check_err()
|
|
||||||
del slocal
|
del slocal
|
||||||
del sglobal
|
del sglobal
|
||||||
|
|
||||||
@ -346,5 +324,4 @@ def version_number():
|
|||||||
Version number of currently stored model
|
Version number of currently stored model
|
||||||
"""
|
"""
|
||||||
ret = _LIB.RabitVersionNumber()
|
ret = _LIB.RabitVersionNumber()
|
||||||
_check_err()
|
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@ -1,2 +0,0 @@
|
|||||||
numpy==1.8.1
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user