xgboost/wrapper/rabit.py
tqchen a16289b204 Squashed 'subtree/rabit/' changes from fa99857..e81a11d
e81a11d Merge pull request #25 from daiyl0320/master
35c3b37 add retry mechanism to ConnectTracker and modify Listen backlog to 128 in rabit_traker.py
c71ed6f try deply doxygen
62e5647 try deply doxygen
732f1c6 try
2fa6e02 ok
0537665 minor
7b59dcb minor
5934950 new doc
f538187 ok
44b6049 new doc
387339b add more
9d4397a chg
2879a48 chg
30e3110 ok
9ff0301 add link translation
6b629c2 k
32e1955 ok
8f4839d fix
93137b2 ok
7eeeb79 reload recommonmark
a8f00cc minor
19b0f01 ok
dd01184 minor
c1cdc19 minor
fcf0f43 try rst
cbc21ae try
62ddfa7 tiny
aefc05c final change
2aee9b4 minor
fe4e7c2 ok
8001983 change to subtitle
5ca33e4 ok
88f7d24 update guide
29d43ab add code
fe8bb3b minor hack for readthedocs
229c71d Merge branch 'master' of ssh://github.com/dmlc/rabit
7424218 ok
d1d45bb Update README.md
1e8813f Update README.md
1ccc990 Update README.md
0323e06 remove readme
679a835 remove theme
7ea5b7c remove numpydoc to napoleon
b73e2be Merge branch 'master' of ssh://github.com/dmlc/rabit
1742283 ok
1838e25 Update python-requirements.txt
bc4e957 ok
fba6fc2 ok
0251101 ok
d50b905 ok
d4f2509 ok
cdf401a ok
fef0ef2 new doc
cef360d ok
c125d2a ok
270a49e add requirments
744f901 get the basic doc
1cb5cad Merge branch 'master' of ssh://github.com/dmlc/rabit
8cc07ba minor
d74f126 Update .travis.yml
52b3dcd Update .travis.yml
099581b Update .travis.yml
1258046 Update .travis.yml
7addac9 Update Makefile
0ea7adf Update .travis.yml
f858856 Update travis_script.sh
d8eac4a Update README.md
3cc49ad lint and travis
ceedf4e fix
fd8920c fix win32
8bbed35 modify
9520b90 Merge pull request #14 from dmlc/hjk41
df14bb1 fix type
f441dc7 replace tab with blankspace
2467942 remove unnecessary include
181ef47 defined long long and ulonglong
1582180 use int32_t to define int and int64_t to define long. in VC long is 32bit
e0b7da0 fix

git-subtree-dir: subtree/rabit
git-subtree-split: e81a11dd7ee3cff87a38a42901315821df018bae
2015-10-20 19:37:47 -07:00

328 lines
9.2 KiB
Python

"""
Reliable Allreduce and Broadcast Library.
Author: Tianqi Chen
"""
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
import cPickle as pickle
import ctypes
import os
import sys
import warnings
import numpy as np
# version of this python wrapper module
__version__ = '1.0'

# Anchor the wrapper library on the *absolute* directory of this module.
# os.path.dirname(__file__) is '' when __file__ carries no directory part,
# and '' + '/librabit_wrapper%s.so' would point at the filesystem root.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

# WRAPPER_PATH keeps a single '%s' placeholder; _loadlib() fills it with the
# library flavour suffix ('', '_mock' or '_mpi').
if os.name == 'nt':
    WRAPPER_PATH = _MODULE_DIR + '\\..\\windows\\x64\\Release\\rabit_wrapper%s.dll'
else:
    WRAPPER_PATH = _MODULE_DIR + '/librabit_wrapper%s.so'

# handle of the loaded shared library; stays None until _loadlib() assigns it
_LIB = None
# load in rabit library
def _loadlib(lib='standard'):
    """Load the rabit shared library into the module-level ``_LIB`` handle.

    If a library is already loaded, a warning is issued and the call
    returns without touching the existing handle.

    Parameters
    ----------
    lib : {'standard', 'mock', 'mpi'}
        Flavour of the wrapper library to load; selects the file-name
        suffix substituted into ``WRAPPER_PATH``.

    Raises
    ------
    ValueError
        If ``lib`` is not one of the supported flavours.
    """
    global _LIB
    if _LIB is not None:
        # warnings.warn accepts ``stacklevel`` (there is no ``level`` keyword);
        # stacklevel=2 attributes the warning to the caller of this helper.
        warnings.warn('rabit.init call was ignored because it has'
                      ' already been initialized', stacklevel=2)
        return
    suffixes = {'standard': '', 'mock': '_mock', 'mpi': '_mpi'}
    if lib not in suffixes:
        # ValueError is an Exception subclass, so existing handlers still match
        raise ValueError('unknown rabit lib %s, can be standard, mock, mpi' % lib)
    _LIB = ctypes.cdll.LoadLibrary(WRAPPER_PATH % suffixes[lib])
    # declare the int-returning entry points explicitly
    _LIB.RabitGetRank.restype = ctypes.c_int
    _LIB.RabitGetWorldSize.restype = ctypes.c_int
    _LIB.RabitVersionNumber.restype = ctypes.c_int
def _unloadlib():
    """Release the module-level handle to the rabit library.

    After this call ``_LIB`` is ``None`` again, so a later ``_loadlib``
    call will load a library instead of warning.
    """
    global _LIB
    # drop the current binding, then rebind the name to the "not loaded" marker
    del _LIB
    _LIB = None
# integer codes of the reduction operators accepted by allreduce()
MAX, MIN, SUM, BITOR = range(4)
def init(args=None, lib='standard'):
    """Initialize the rabit module; call this once before using anything else.

    Parameters
    ----------
    args : list of str, optional
        Arguments handed to the rabit engine. ``sys.argv`` is used when
        this is None.
    lib : {'standard', 'mock', 'mpi'}
        Type of library we want to load.
    """
    argv = sys.argv if args is None else args
    _loadlib(lib)
    argc = len(argv)
    # build a C array of char* holding the arguments
    c_argv = (ctypes.c_char_p * argc)(*argv)
    _LIB.RabitInit(argc, c_argv)
def finalize():
    """Shut down the rabit engine.

    Call this once all jobs are finished. The engine is finalized first
    and the library handle is released afterwards.
    """
    _LIB.RabitFinalize()
    _unloadlib()
def get_rank():
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        Rank of current process, as reported by ``RabitGetRank``.
    """
    return _LIB.RabitGetRank()
def get_world_size():
    """Return the total number of workers.

    Returns
    -------
    n : int
        Total number of processes, as reported by ``RabitGetWorldSize``.
    """
    return _LIB.RabitGetWorldSize()
def tracker_print(msg):
    """Print message to the tracker.

    This function can be used to communicate the information of
    the progress to the tracker.

    Parameters
    ----------
    msg : str
        The message to be printed to tracker. Objects that are neither
        byte strings nor text are converted with ``str`` first.
    """
    if not isinstance(msg, bytes):
        if not isinstance(msg, type(u'')):
            msg = str(msg)
        if not isinstance(msg, bytes):
            # Encode the text *before* wrapping it: a ctypes.c_char_p object
            # has no .encode method, and it needs a byte string to point at.
            msg = msg.encode('utf-8')
    _LIB.RabitTrackerPrint(ctypes.c_char_p(msg))
def get_processor_name():
    """Return the processor (host) name.

    Returns
    -------
    name : str
        The name of processor(host), read from a 256-byte buffer up to
        its first NUL byte.
    """
    max_len = 256
    name_buf = ctypes.create_string_buffer(max_len)
    # the library also writes a length here; only the buffer content is used
    name_len = ctypes.c_ulong()
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(name_len), max_len)
    return name_buf.value
def broadcast(data, root):
    """Broadcast an object from one node to all other nodes.

    The object is pickled on the root. Two broadcasts are issued: the
    first carries the byte length of the pickle, the second the pickle
    itself.

    Parameters
    ----------
    data : any type that can be pickled
        Input data; if current rank does not equal root, this can be None.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object
        On the root, ``data`` itself; on every other rank, the object
        unpickled from the received bytes.
    """
    is_root = (root == get_rank())
    nbytes = ctypes.c_ulong()
    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        nbytes.value = len(payload)
    # round one: every rank learns the payload size
    _LIB.RabitBroadcast(ctypes.byref(nbytes),
                        ctypes.sizeof(ctypes.c_ulong), root)
    if is_root:
        # round two (sender side): ship the pickled bytes
        _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
                            nbytes.value, root)
        return data
    # round two (receiver side): fill a buffer of the announced size
    recv_buf = (ctypes.c_char * nbytes.value)()
    _LIB.RabitBroadcast(ctypes.cast(recv_buf, ctypes.c_void_p),
                        nbytes.value, root)
    return pickle.loads(recv_buf.raw)
# numpy dtype -> integer type code passed to RabitAllreduce
DTYPE_ENUM__ = dict(
    (np.dtype(type_name), type_code)
    for type_name, type_code in (
        ('int8', 0),
        ('uint8', 1),
        ('int32', 2),
        ('uint32', 3),
        ('int64', 4),
        ('uint64', 5),
        ('float32', 6),
        ('float64', 7),
    )
)
def allreduce(data, op, prepare_fun=None):
    """Perform allreduce, return the result.

    Parameters
    ----------
    data : numpy array
        Input data.
    op : int
        Reduction operators, can be MIN, MAX, SUM, BITOR
    prepare_fun : function
        Lazy preprocessing function, if it is not None, prepare_fun(data)
        will be called by the function before performing allreduce, to
        initialize the data.
        If the result of Allreduce can be recovered directly,
        then prepare_fun will NOT be called.

    Returns
    -------
    result : numpy.ndarray
        Flattened (1-D) array with ``data.size`` elements holding the
        result. Depending on the memory layout of ``data`` this is either
        a private copy or a view sharing memory with ``data``.

    Raises
    ------
    TypeError
        If ``data`` is not a ``numpy.ndarray`` or its dtype has no entry
        in ``DTYPE_ENUM__``.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise TypeError('data type %s not supported' % str(buf.dtype))
    buf_ptr = buf.ctypes.data_as(ctypes.c_void_p)
    dtype_code = DTYPE_ENUM__[buf.dtype]
    if prepare_fun is None:
        _LIB.RabitAllreduce(buf_ptr, buf.size, dtype_code, op, None, None)
    else:
        # buf may be a detached copy (explicit copy above, or ravel() of a
        # non-contiguous array); prepare_fun writes into ``data``, so the
        # buffer handed to the library has to be refreshed afterwards.
        detached = not np.may_share_memory(buf, data)
        func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

        def pfunc(args):
            """prepare function."""
            prepare_fun(data)
            if detached:
                buf[...] = data.ravel()

        # keep a named reference to the callback for the duration of the call
        callback = func_ptr(pfunc)
        _LIB.RabitAllreduce(buf_ptr, buf.size, dtype_code, op, callback, None)
    return buf
def _load_model(ptr, length):
    """Unpickle a model from a raw memory region (internal helper).

    Parameters
    ----------
    ptr : ctypes.POINTER(ctypes.c_char)
        Pointer to the first byte of the buffer.
    length : int
        Number of bytes in the buffer.

    Returns
    -------
    object
        The object unpickled from the ``length`` bytes starting at ``ptr``.
    """
    start = ctypes.addressof(ptr.contents)
    # view the region as a fixed-size char array so embedded NULs survive
    region = (ctypes.c_char * length).from_address(start)
    raw_bytes = region.raw
    return pickle.loads(raw_bytes)
def load_checkpoint(with_local=False):
    """Load the latest checkpoint.

    Parameters
    ----------
    with_local : bool, optional
        Whether the checkpoint contains a local model.

    Returns
    -------
    tuple
        ``(version, global_model, local_model)`` if ``with_local`` is set,
        otherwise ``(version, global_model)``.
        A returned version of 0 means no model has been checkpointed; the
        model entries of the tuple are then None.
    """
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_len = ctypes.c_ulong()

    if not with_local:
        # only the global model is requested: pass null for the local slots
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(global_ptr),
            ctypes.byref(global_len),
            None, None)
        if version == 0:
            return (version, None)
        global_model = _load_model(global_ptr, global_len.value)
        return (version, global_model)

    local_ptr = ctypes.POINTER(ctypes.c_char)()
    local_len = ctypes.c_ulong()
    version = _LIB.RabitLoadCheckPoint(
        ctypes.byref(global_ptr),
        ctypes.byref(global_len),
        ctypes.byref(local_ptr),
        ctypes.byref(local_len))
    if version == 0:
        return (version, None, None)
    global_model = _load_model(global_ptr, global_len.value)
    local_model = _load_model(local_ptr, local_len.value)
    return (version, global_model, local_model)
def checkpoint(global_model, local_model=None):
    """Checkpoint the model.

    This means we finished a stage of execution.
    Every time we call checkpoint, there is a version number which will
    increase by one.

    Parameters
    ----------
    global_model : any type that can be pickled
        Globally shared model/state when calling this function;
        the caller needs to guarantee that global_model is the same
        on all nodes.
    local_model : any type that can be pickled
        Local model that is specific to the current node/rank.
        This can be None when no local state is needed.

    Notes
    -----
    local_model requires explicit replication of the model for
    fault-tolerance, which adds replication cost to the checkpoint call,
    while global_model does not need explicit replication.
    It is recommended to use global_model if possible.
    """
    global_blob = pickle.dumps(global_model)
    if local_model is not None:
        local_blob = pickle.dumps(local_model)
        _LIB.RabitCheckPoint(global_blob, len(global_blob),
                             local_blob, len(local_blob))
    else:
        # no local state: pass a null buffer of length zero
        _LIB.RabitCheckPoint(global_blob, len(global_blob), None, 0)
def version_number():
    """Return the version number of the currently stored model.

    This means how many calls to CheckPoint we made so far.

    Returns
    -------
    version : int
        Version number of currently stored model, as reported by
        ``RabitVersionNumber``.
    """
    return _LIB.RabitVersionNumber()