e81a11d Merge pull request #25 from daiyl0320/master 35c3b37 add retry mechanism to ConnectTracker and modify Listen backlog to 128 in rabit_traker.py c71ed6f try deply doxygen 62e5647 try deply doxygen 732f1c6 try 2fa6e02 ok 0537665 minor 7b59dcb minor 5934950 new doc f538187 ok 44b6049 new doc 387339b add more 9d4397a chg 2879a48 chg 30e3110 ok 9ff0301 add link translation 6b629c2 k 32e1955 ok 8f4839d fix 93137b2 ok 7eeeb79 reload recommonmark a8f00cc minor 19b0f01 ok dd01184 minor c1cdc19 minor fcf0f43 try rst cbc21ae try 62ddfa7 tiny aefc05c final change 2aee9b4 minor fe4e7c2 ok 8001983 change to subtitle 5ca33e4 ok 88f7d24 update guide 29d43ab add code fe8bb3b minor hack for readthedocs 229c71d Merge branch 'master' of ssh://github.com/dmlc/rabit 7424218 ok d1d45bb Update README.md 1e8813f Update README.md 1ccc990 Update README.md 0323e06 remove readme 679a835 remove theme 7ea5b7c remove numpydoc to napoleon b73e2be Merge branch 'master' of ssh://github.com/dmlc/rabit 1742283 ok 1838e25 Update python-requirements.txt bc4e957 ok fba6fc2 ok 0251101 ok d50b905 ok d4f2509 ok cdf401a ok fef0ef2 new doc cef360d ok c125d2a ok 270a49e add requirments 744f901 get the basic doc 1cb5cad Merge branch 'master' of ssh://github.com/dmlc/rabit 8cc07ba minor d74f126 Update .travis.yml 52b3dcd Update .travis.yml 099581b Update .travis.yml 1258046 Update .travis.yml 7addac9 Update Makefile 0ea7adf Update .travis.yml f858856 Update travis_script.sh d8eac4a Update README.md 3cc49ad lint and travis ceedf4e fix fd8920c fix win32 8bbed35 modify 9520b90 Merge pull request #14 from dmlc/hjk41 df14bb1 fix type f441dc7 replace tab with blankspace 2467942 remove unnecessary include 181ef47 defined long long and ulonglong 1582180 use int32_t to define int and int64_t to define long. in VC long is 32bit e0b7da0 fix git-subtree-dir: subtree/rabit git-subtree-split: e81a11dd7ee3cff87a38a42901315821df018bae
328 lines
9.2 KiB
Python
328 lines
9.2 KiB
Python
"""
|
|
Reliable Allreduce and Broadcast Library.
|
|
|
|
Author: Tianqi Chen
|
|
"""
|
|
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
|
|
import cPickle as pickle
|
|
import ctypes
|
|
import os
|
|
import sys
|
|
import warnings
|
|
import numpy as np
|
|
|
|
# version information about the doc
__version__ = '1.0'

# Directory holding this module. abspath() matters when __file__ is a bare
# relative name (e.g. 'rabit.pyc' imported from the working directory): then
# dirname() is '' and naive string concatenation would produce a path rooted
# at the filesystem root such as '/librabit_wrapper.so'.
_RABIT_DIR = os.path.dirname(os.path.abspath(__file__))

# Path template of the native wrapper library; the '%s' placeholder receives
# the library-variant suffix ('', '_mock' or '_mpi') when the library is loaded.
if os.name == 'nt':
    WRAPPER_PATH = os.path.join(_RABIT_DIR, '..', 'windows', 'x64', 'Release',
                                'rabit_wrapper%s.dll')
else:
    WRAPPER_PATH = os.path.join(_RABIT_DIR, 'librabit_wrapper%s.so')

# ctypes handle of the loaded native library; None until it has been loaded.
_LIB = None
|
|
|
|
# load in rabit library
def _loadlib(lib='standard'):
    """Load the rabit native library into the module-level ``_LIB`` handle.

    Parameters
    ----------
    lib : {'standard', 'mock', 'mpi'}
        Which wrapper variant to load; selects the suffix substituted into
        ``WRAPPER_PATH``.

    Raises
    ------
    ValueError
        If ``lib`` is not one of the supported variants.

    Notes
    -----
    If a library handle is already present, a warning is issued and the call
    does nothing.
    """
    global _LIB
    if _LIB is not None:
        # ``stacklevel`` is the keyword warnings.warn accepts; the previous
        # ``level=2`` raised TypeError instead of emitting the warning.
        warnings.warn('rabit.init call was ignored because it has'
                      ' already been initialized', stacklevel=2)
        return
    suffixes = {'standard': '', 'mock': '_mock', 'mpi': '_mpi'}
    if lib not in suffixes:
        raise ValueError('unknown rabit lib %s, can be standard, mock, mpi' % lib)
    handle = ctypes.cdll.LoadLibrary(WRAPPER_PATH % suffixes[lib])
    handle.RabitGetRank.restype = ctypes.c_int
    handle.RabitGetWorldSize.restype = ctypes.c_int
    handle.RabitVersionNumber.restype = ctypes.c_int
    # Publish the handle only once fully configured, so a failure above
    # leaves _LIB as None and a later call can retry.
    _LIB = handle
|
|
|
|
def _unloadlib():
    """Drop the module's reference to the rabit native library.

    Resets ``_LIB`` to None so that the library can be loaded again later.
    NOTE(review): nothing here unloads the shared object itself; presumably
    it stays mapped in the process and only the Python handle is released.
    """
    global _LIB
    # Rebinding is enough; a preceding ``del`` was redundant and would raise
    # NameError if the global did not exist.
    _LIB = None
|
|
|
|
# Reduction operator codes accepted by allreduce(); the integer value is
# forwarded unchanged to the native RabitAllreduce call.
MAX, MIN, SUM, BITOR = range(4)
|
|
|
|
def init(args=None, lib='standard'):
    """Initialize the rabit module; call this once before using anything else.

    Parameters
    ----------
    args : list of str, optional
        Command-line style arguments forwarded to the rabit engine, usually
        ``sys.argv``. When None, ``sys.argv`` is used.
    lib : {'standard', 'mock', 'mpi'}
        Which rabit library variant to load.
    """
    argv = sys.argv if args is None else args
    _loadlib(lib)
    argc = len(argv)
    # Build a C ``char *[argc]`` array initialised from the argument list.
    c_argv = (ctypes.c_char_p * argc)(*argv)
    _LIB.RabitInit(argc, c_argv)
|
|
|
|
def finalize():
    """Shut down the rabit engine.

    Call this after all jobs are finished. Afterwards the module's library
    handle is released, so ``init`` must be called again before further use.
    """
    lib = _LIB
    lib.RabitFinalize()
    _unloadlib()
|
|
|
|
def get_rank():
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        Rank of the current process, as reported by the native library.
    """
    return _LIB.RabitGetRank()
|
|
|
|
def get_world_size():
    """Return the total number of worker processes.

    Returns
    -------
    n : int
        Number of processes, as reported by the native library.
    """
    return _LIB.RabitGetWorldSize()
|
|
|
|
def _tracker_bytes(msg):
    """Convert ``msg`` into a UTF-8 byte string usable as a ``char *`` argument."""
    if isinstance(msg, bytes):
        return msg
    text_type = type(u'')
    if not isinstance(msg, text_type):
        msg = str(msg)
        if isinstance(msg, bytes):
            # Python 2: str() already produced a byte string.
            return msg
    return msg.encode('utf-8')


def tracker_print(msg):
    """Print a message on the tracker.

    This can be used to report the progress of a job to the tracker.

    Parameters
    ----------
    msg : str
        The message to print; non-string objects are converted with
        ``str()`` and text is encoded as UTF-8.
    """
    # The old code called ``.encode`` on the c_char_p object itself, which
    # always raised AttributeError; encode first, then wrap.
    _LIB.RabitTrackerPrint(ctypes.c_char_p(_tracker_bytes(msg)))
|
|
|
|
def get_processor_name():
    """Return the name of the processor (host) running this process.

    Returns
    -------
    name : str
        The name written by the native library into a 256-byte buffer,
        read up to the first NUL byte.
    """
    max_len = 256
    name_buf = ctypes.create_string_buffer(max_len)
    out_len = ctypes.c_ulong()
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(out_len), max_len)
    return name_buf.value
|
|
|
|
def broadcast(data, root):
    """Broadcast a picklable object from one node to all other nodes.

    Parameters
    ----------
    data : any type that can be pickled
        Object to send. Only used on the ``root`` rank; other ranks may pass
        None.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object : any
        On ``root``, the ``data`` argument itself. On every other rank, the
        object unpickled from the bytes received from ``root``.
    """
    rank = get_rank()
    is_root = root == rank
    length = ctypes.c_ulong()
    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(payload)
    # First round: every rank receives the payload size from root.
    _LIB.RabitBroadcast(ctypes.byref(length),
                        ctypes.sizeof(ctypes.c_ulong), root)
    if is_root:
        # Second round: root sends the pickled bytes.
        _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
                            length.value, root)
        return data
    # Second round: receive exactly ``length`` bytes into a fresh buffer.
    recv_buf = (ctypes.c_char * length.value)()
    _LIB.RabitBroadcast(ctypes.cast(recv_buf, ctypes.c_void_p),
                        length.value, root)
    # NOTE(review): pickle.loads can execute code embedded in the payload;
    # this trusts every peer participating in the rabit job.
    return pickle.loads(recv_buf.raw)
|
|
|
|
# Map from each supported numpy dtype to the integer type code passed to the
# native RabitAllreduce call; the position in this list defines the code.
DTYPE_ENUM__ = {
    np.dtype(type_name): type_code
    for type_code, type_name in enumerate(['int8', 'uint8', 'int32', 'uint32',
                                           'int64', 'uint64', 'float32', 'float64'])
}
|
|
|
|
def _bind_prepare(prepare_fun, data, buf):
    """Wrap ``prepare_fun`` so values it writes into ``data`` also reach ``buf``."""
    def pfunc(_):
        """prepare function."""
        prepare_fun(data)
        # buf can be a separate copy of data (e.g. when data is a view or is
        # non-contiguous); without this copy the prepared values would never
        # reach the buffer that is actually reduced.
        buf[:] = data.ravel()
    return pfunc


def allreduce(data, op, prepare_fun=None):
    """Perform allreduce and return the result.

    Parameters
    ----------
    data : numpy array
        Input data.
    op : int
        Reduction operator: one of MIN, MAX, SUM, BITOR.
    prepare_fun : function
        Lazy preprocessing function. If not None, ``prepare_fun(data)`` is
        called before the allreduce is performed, to initialize the data.
        If the result of the allreduce can be recovered directly, then
        ``prepare_fun`` will NOT be called.

    Returns
    -------
    result : numpy array
        One-dimensional array of ``data.size`` elements holding the reduced
        values. It may share memory with ``data`` (and so overwrite it), e.g.
        when ``data`` is a C-contiguous array that owns its memory.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise Exception('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise Exception('data type %s not supported' % str(buf.dtype))
    if prepare_fun is None:
        callback = None
    else:
        # Keep the ctypes callback in a local so it outlives the native call.
        callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
            _bind_prepare(prepare_fun, data, buf))
    _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                        buf.size, DTYPE_ENUM__[buf.dtype],
                        op, callback, None)
    return buf
|
|
|
|
|
|
def _load_model(ptr, length):
    """Unpickle an object from a raw memory buffer (internal helper).

    Parameters
    ----------
    ptr : ctypes.POINTER(ctypes.c_char)
        Pointer to the first byte of the buffer.
    length : int
        Number of bytes in the buffer.

    Returns
    -------
    The unpickled object.
    """
    # Reinterpret the pointer as pointing to a fixed-size char array, then
    # copy that array's bytes out.
    region = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_char * length)).contents
    # NOTE(review): pickle.loads can execute code embedded in the payload;
    # this trusts whatever the native checkpoint layer hands back.
    return pickle.loads(region.raw)
|
|
|
|
def load_checkpoint(with_local=False):
    """Load the latest checkpoint.

    Parameters
    ----------
    with_local : bool, optional
        Whether the checkpoint also contains a local model.

    Returns
    -------
    tuple : tuple
        ``(version, global_model, local_model)`` when ``with_local`` is true,
        otherwise ``(version, global_model)``.
        A returned version of 0 means nothing has been checkpointed yet; the
        model entries are then None.
    """
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_len = ctypes.c_ulong()
    if with_local:
        local_ptr = ctypes.POINTER(ctypes.c_char)()
        local_len = ctypes.c_ulong()
        local_args = (ctypes.byref(local_ptr), ctypes.byref(local_len))
    else:
        # No local model requested: pass NULL out-parameters.
        local_args = (None, None)
    version = _LIB.RabitLoadCheckPoint(ctypes.byref(global_ptr),
                                       ctypes.byref(global_len),
                                       *local_args)
    if version == 0:
        return (version, None, None) if with_local else (version, None)
    global_model = _load_model(global_ptr, global_len.value)
    if not with_local:
        return (version, global_model)
    return (version, global_model, _load_model(local_ptr, local_len.value))
|
|
|
|
def checkpoint(global_model, local_model=None):
    """Checkpoint the model, marking the end of a stage of execution.

    Each call increases the stored version number by one.

    Parameters
    ----------
    global_model : any type that can be pickled
        Globally shared model/state at the time of the call. The caller must
        guarantee that global_model is the same on all nodes.
    local_model : any type that can be pickled
        Model specific to the current node/rank. May be None when no local
        state is needed.

    Notes
    -----
    local_model requires explicit replication for fault tolerance, which adds
    replication cost to this call, while global_model does not need explicit
    replication. Prefer global_model where possible.
    """
    sglobal = pickle.dumps(global_model)
    if local_model is None:
        # No local state: pass a NULL buffer of length 0.
        slocal, local_len = None, 0
    else:
        slocal = pickle.dumps(local_model)
        local_len = len(slocal)
    _LIB.RabitCheckPoint(sglobal, len(sglobal), slocal, local_len)
|
|
|
|
def version_number():
    """Return the version number of the currently stored model.

    This is the number of checkpoint calls made so far.

    Returns
    -------
    version : int
        Version number of the currently stored model.
    """
    return _LIB.RabitVersionNumber()
|