add auto caching of python in hadoop script, mock test module to python, with checkpt

This commit is contained in:
tqchen
2015-01-13 14:29:10 -08:00
parent 877fc42e40
commit 3419cf9aa7
10 changed files with 104 additions and 144 deletions

View File

@@ -7,18 +7,34 @@ import cPickle as pickle
import ctypes
import os
import sys
import warnings
import numpy as np
# Platform check: the Windows build of the wrapper is rejected outright;
# every other platform uses the .so files that sit next to this module.
if os.name == 'nt':
assert False, "Rabit windows is not yet compiled"
else:
# full path of the default wrapper library, located beside this file
RABIT_PATH = os.path.dirname(__file__)+'/librabit_wrapper.so'
# directory of this file; loadlib__ builds the library path from it
WRAPPER_PATH = os.path.dirname(__file__)
# module-level handle to the loaded wrapper library (None = not loaded)
rbtlib = None
# load in rabit wrapper library
# NOTE(review): this eager load (and RABIT_PATH above) repeats what loadlib__
# does on demand; presumably these are the pre-change lines of the diff and
# only the lazy path is meant to remain -- confirm before relying on either.
rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH)
# declare the C return type of the wrapper functions as int
rbtlib.RabitGetRank.restype = ctypes.c_int
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
rbtlib.RabitVersionNumber.restype = ctypes.c_int
def loadlib__(with_mock = False):
    """
    load the rabit wrapper shared library into the module-level handle rbtlib

    Arguments:
        with_mock: boolean [default=False]
            if True load librabit_wrapper_mock.so, otherwise load
            librabit_wrapper.so; both are looked up under WRAPPER_PATH

    If rbtlib is already set, a warning is issued and the call is ignored,
    leaving the existing handle untouched.
    """
    global rbtlib
    if rbtlib is not None:
        # stacklevel=2 attributes the warning to the caller of loadlib__
        # rather than to this line
        warnings.warn('rabit.init call was ignored because it has already been initialized',
                      stacklevel = 2)
        return
    if with_mock:
        libname = 'librabit_wrapper_mock.so'
    else:
        libname = 'librabit_wrapper.so'
    lib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/' + libname)
    # these wrapper functions return C int
    lib.RabitGetRank.restype = ctypes.c_int
    lib.RabitGetWorldSize.restype = ctypes.c_int
    lib.RabitVersionNumber.restype = ctypes.c_int
    # publish the handle only once it is fully configured, so a failure
    # above does not leave a half-initialised rbtlib behind
    rbtlib = lib
def unloadlib__():
    """
    drop the module-level handle to the rabit wrapper library

    rbtlib is reset to None, which is the state loadlib__ checks for before
    loading the library. Only the Python reference is released here; this
    function does not itself unload the shared object from the process.
    """
    global rbtlib
    # rebinding already releases the previous handle; a separate `del` first
    # adds nothing and would raise NameError if the name were missing
    rbtlib = None
# reduction operators
MAX = 0
@@ -32,14 +48,17 @@ def check_err__():
"""
return
def init(args = sys.argv):
def init(args = sys.argv, with_mock = False):
"""
intialize the rabit module, call this once before using anything
Arguments:
args: list(string)
the list of arguments used to initialized the rabit
usually you need to pass in sys.argv
args: list(string) [default=sys.argv]
the list of arguments used to initialized the rabit
usually you need to pass in sys.argv
with_mock: boolean [default=False]
Whether initialize the mock test module
"""
loadlib__(with_mock)
arr = (ctypes.c_char_p * len(args))()
arr[:] = args
rbtlib.RabitInit(len(args), arr)
@@ -51,6 +70,7 @@ def finalize():
"""
rbtlib.RabitFinalize()
check_err__()
unloadlib__()
def get_rank():
"""
@@ -156,7 +176,7 @@ DTYPE_ENUM__ = {
np.dtype('float64') : 7
}
def allreduce(data, op):
def allreduce(data, op, prepare_fun = None):
"""
perform allreduce, return the result, this function is not thread-safe
Arguments:
@@ -164,8 +184,8 @@ def allreduce(data, op):
input data
op: int
reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: lambda
Lazy preprocessing function, if it is not None, prepare_fun()
prepare_fun: lambda data
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to intialize the data
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
Returns:
@@ -178,9 +198,17 @@ def allreduce(data, op):
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise Exception('data type %s not supported' % str(buf.dtype))
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[dtype],
op, None, None);
if prepare_fun is None:
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, None, None)
else:
PFUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pfunc(args):
prepare_fun(data)
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, PFUNC(pfunc), None)
check_err__()
return buf
@@ -273,4 +301,3 @@ def version_number():
ret = rbtlib.RabitVersionNumber()
check_err__()
return ret