From 2d72c853df72797c01d2bf7bb320b69e585293c0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 12 Jan 2015 22:32:13 -0800 Subject: [PATCH] checkin broadcast python module --- Makefile | 3 +- wrapper/rabit.py | 126 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 wrapper/rabit.py diff --git a/Makefile b/Makefile index daa08aa79..3030f3f4d 100644 --- a/Makefile +++ b/Makefile @@ -47,4 +47,5 @@ $(SLIB) : $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) clean: - $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ + $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~ + diff --git a/wrapper/rabit.py b/wrapper/rabit.py new file mode 100644 index 000000000..19e7b2f9d --- /dev/null +++ b/wrapper/rabit.py @@ -0,0 +1,126 @@ +""" +Python interface for rabit + Reliable Allreduce and Broadcast Library +Author: Tianqi Chen +""" +import cPickle as pickle +import ctypes +import os +import sys +if os.name == 'nt': + assert False, "Rabit windows is not yet compiled" +else: + RABIT_PATH = os.path.dirname(__file__)+'/librabit_wrapper.so' + +# load in xgboost library +rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH) +rbtlib.RabitGetRank.restype = ctypes.c_int +rbtlib.RabitGetWorldSize.restype = ctypes.c_int + +def check_err__(): + """ + reserved function used to check error + """ + return + +def init(args = sys.argv): + """ + intialize the rabit module, call this once before using anything + Arguments: + args: list(string) + the list of arguments used to initialized the rabit + usually you need to pass in sys.argv + """ + arr = (ctypes.c_char_p * len(args))() + arr[:] = args + rbtlib.RabitInit(len(args), arr) + check_err__() + +def finalize(): + """ + finalize the rabit engine, call this function after you finished all jobs + """ + rbtlib.RabitFinalize() + check_err__() + +def get_rank(): + """ + Returns rank of current process + """ + ret = rbtlib.RabitGetRank() + check_err__() + return ret + +def get_world_size(): + """ + Returns get total number of process + """ + ret = rbtlib.RabitGetWorlSize() + check_err__() + return ret + +def tracker_print(msg): + """ + print message to the tracker + this function can be used to communicate the information of the progress + to the tracker + """ + if not isinstance(msg, str): + msg = str(msg) + rbtlib.RabitTrackerPrint(ctypes.c_char_p(msg).encode('utf-8')) + check_err__() + +def get_processor_name(): + """ + Returns the name of processor(host) + """ + mxlen = 256 + length = ctypes.c_ulong() + buf = ctypes.create_string_buffer(mxlen) + rbtlib.RabitGetProcessorName(buf, ctypes.byref(length), + mxlen) + check_err__() + return buf.value + +def broadcast(data, root): + """ + broadcast object from one node to all other nodes + this function will return the broadcasted object + + Example: the following example broadcast hello from rank 0 to all other nodes + ```python + import rabit + rabit.init() + if rabit.get_rank() == 0: + res = rabit.broadcast('hello', 0) + else: + res = rabit.broadcast(None, 0) + rabit.finalize() + ``` + + Arguments: + data: anytype that can be pickled + input data, if current rank does not equal root, this can be None + root: int + rank of the node to broadcast data from + """ + rank = get_rank() + length = ctypes.c_ulong() + if root == rank: + assert data is not None, 'need to pass in data when broadcasting' + s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL) + length.value = len(s) + dptr = (ctypes.c_char * length.value)() + dptr[:] = s + # run first broadcast + rbtlib.RabitBroadcast(ctypes.byref(length), + ctypes.sizeof(ctypes.c_ulong), + root) + check_err__() + if root != rank: + dptr = (ctypes.c_char * length.value)() + # run second + rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), + length.value, root) + check_err__() + return pickle.loads(dptr.value)