checkin broadcast python module

This commit is contained in:
tqchen 2015-01-12 22:32:13 -08:00
parent 9a4a81f100
commit 2d72c853df
2 changed files with 128 additions and 1 deletions

View File

@ -47,4 +47,5 @@ $(SLIB) :
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^)
clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~

126
wrapper/rabit.py Normal file
View File

@ -0,0 +1,126 @@
"""
Python interface for rabit
Reliable Allreduce and Broadcast Library
Author: Tianqi Chen
"""
import cPickle as pickle
import ctypes
import os
import sys
if os.name == 'nt':
assert False, "Rabit windows is not yet compiled"
else:
RABIT_PATH = os.path.dirname(__file__)+'/librabit_wrapper.so'
# load in xgboost library
rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH)
rbtlib.RabitGetRank.restype = ctypes.c_int
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
def check_err__():
"""
reserved function used to check error
"""
return
def init(args = sys.argv):
"""
intialize the rabit module, call this once before using anything
Arguments:
args: list(string)
the list of arguments used to initialized the rabit
usually you need to pass in sys.argv
"""
arr = (ctypes.c_char_p * len(args))()
arr[:] = args
rbtlib.RabitInit(len(args), arr)
check_err__()
def finalize():
"""
finalize the rabit engine, call this function after you finished all jobs
"""
rbtlib.RabitFinalize()
check_err__()
def get_rank():
"""
Returns rank of current process
"""
ret = rbtlib.RabitGetRank()
check_err__()
return ret
def get_world_size():
"""
Returns get total number of process
"""
ret = rbtlib.RabitGetWorlSize()
check_err__()
return ret
def tracker_print(msg):
"""
print message to the tracker
this function can be used to communicate the information of the progress
to the tracker
"""
if not isinstance(msg, str):
msg = str(msg)
rbtlib.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8'))
check_err__()
def get_processor_name():
"""
Returns the name of processor(host)
"""
mxlen = 256
length = ctypes.c_ulong()
buf = ctypes.create_string_buffer(mxlen)
rbtlib.RabitGetProcessorName(buf, ctypes.byref(length),
mxlen)
check_err__()
return buf.value
def broadcast(data, root):
"""
broadcast object from one node to all other nodes
this function will return the broadcasted object
Example: the following example broadcast hello from rank 0 to all other nodes
```python
import rabit
rabit.init()
if rabit.get_rank() == 0:
res = rabit.broadcast('hello', 0)
else:
res = rabit.broadcast(None, 0)
rabit.finalize()
```
Arguments:
data: anytype that can be pickled
input data, if current rank does not equal root, this can be None
root: int
rank of the node to broadcast data from
"""
rank = get_rank()
length = ctypes.c_ulong()
if root == rank:
assert data is not None, 'need to pass in data when broadcasting'
s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL)
length.value = len(s)
dptr = (ctypes.c_char * length.value)()
dptr[:] = s
# run first broadcast
rbtlib.RabitBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong),
root)
check_err__()
if root != rank:
dptr = (ctypes.c_char * length.value)()
# run second
rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
length.value, root)
check_err__()
return pickle.loads(dptr.value)