add auto caching of python in hadoop script, mock test module to python, with checkpt
This commit is contained in:
@@ -7,18 +7,34 @@ import cPickle as pickle
|
||||
import ctypes
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
if os.name == 'nt':
|
||||
assert False, "Rabit windows is not yet compiled"
|
||||
else:
|
||||
RABIT_PATH = os.path.dirname(__file__)+'/librabit_wrapper.so'
|
||||
WRAPPER_PATH = os.path.dirname(__file__)
|
||||
|
||||
rbtlib = None
|
||||
|
||||
# load in xgboost library
|
||||
rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH)
|
||||
rbtlib.RabitGetRank.restype = ctypes.c_int
|
||||
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
||||
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||
def loadlib__(with_mock = False):
|
||||
global rbtlib
|
||||
if rbtlib != None:
|
||||
warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2)
|
||||
return
|
||||
if with_mock:
|
||||
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mock.so')
|
||||
else:
|
||||
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper.so')
|
||||
rbtlib.RabitGetRank.restype = ctypes.c_int
|
||||
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
||||
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||
|
||||
def unloadlib__():
|
||||
global rbtlib
|
||||
del rbtlib
|
||||
rbtlib = None
|
||||
|
||||
# reduction operators
|
||||
MAX = 0
|
||||
@@ -32,14 +48,17 @@ def check_err__():
|
||||
"""
|
||||
return
|
||||
|
||||
def init(args = sys.argv):
|
||||
def init(args = sys.argv, with_mock = False):
|
||||
"""
|
||||
intialize the rabit module, call this once before using anything
|
||||
Arguments:
|
||||
args: list(string)
|
||||
the list of arguments used to initialized the rabit
|
||||
usually you need to pass in sys.argv
|
||||
args: list(string) [default=sys.argv]
|
||||
the list of arguments used to initialized the rabit
|
||||
usually you need to pass in sys.argv
|
||||
with_mock: boolean [default=False]
|
||||
Whether initialize the mock test module
|
||||
"""
|
||||
loadlib__(with_mock)
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = args
|
||||
rbtlib.RabitInit(len(args), arr)
|
||||
@@ -51,6 +70,7 @@ def finalize():
|
||||
"""
|
||||
rbtlib.RabitFinalize()
|
||||
check_err__()
|
||||
unloadlib__()
|
||||
|
||||
def get_rank():
|
||||
"""
|
||||
@@ -156,7 +176,7 @@ DTYPE_ENUM__ = {
|
||||
np.dtype('float64') : 7
|
||||
}
|
||||
|
||||
def allreduce(data, op):
|
||||
def allreduce(data, op, prepare_fun = None):
|
||||
"""
|
||||
perform allreduce, return the result, this function is not thread-safe
|
||||
Arguments:
|
||||
@@ -164,8 +184,8 @@ def allreduce(data, op):
|
||||
input data
|
||||
op: int
|
||||
reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun: lambda
|
||||
Lazy preprocessing function, if it is not None, prepare_fun()
|
||||
prepare_fun: lambda data
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to intialize the data
|
||||
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
Returns:
|
||||
@@ -178,9 +198,17 @@ def allreduce(data, op):
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise Exception('data type %s not supported' % str(buf.dtype))
|
||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[dtype],
|
||||
op, None, None);
|
||||
if prepare_fun is None:
|
||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, None, None)
|
||||
else:
|
||||
PFUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
def pfunc(args):
|
||||
prepare_fun(data)
|
||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, PFUNC(pfunc), None)
|
||||
check_err__()
|
||||
return buf
|
||||
|
||||
@@ -273,4 +301,3 @@ def version_number():
|
||||
ret = rbtlib.RabitVersionNumber()
|
||||
check_err__()
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user