This commit is contained in:
tqchen 2015-01-13 12:51:55 -08:00
parent 15e085cd32
commit 877fc42e40
3 changed files with 238 additions and 30 deletions

View File

@ -18,6 +18,7 @@ else:
rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH) rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH)
rbtlib.RabitGetRank.restype = ctypes.c_int rbtlib.RabitGetRank.restype = ctypes.c_int
rbtlib.RabitGetWorldSize.restype = ctypes.c_int rbtlib.RabitGetWorldSize.restype = ctypes.c_int
rbtlib.RabitVersionNumber.restype = ctypes.c_int
# reduction operators # reduction operators
MAX = 0 MAX = 0
@ -111,7 +112,7 @@ def broadcast(data, root):
Arguments: Arguments:
data: anytype that can be pickled data: anytype that can be pickled
input data, if current rank does not equal root, this can be None input data, if current rank does not equal root, this can be None
root: int root: int
rank of the node to broadcast data from rank of the node to broadcast data from
Returns: Returns:
@ -123,8 +124,6 @@ def broadcast(data, root):
assert data is not None, 'need to pass in data when broadcasting' assert data is not None, 'need to pass in data when broadcasting'
s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL) s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL)
length.value = len(s) length.value = len(s)
dptr = (ctypes.c_char * length.value)()
dptr[:] = s
# run first broadcast # run first broadcast
rbtlib.RabitBroadcast(ctypes.byref(length), rbtlib.RabitBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong), ctypes.sizeof(ctypes.c_ulong),
@ -132,20 +131,40 @@ def broadcast(data, root):
check_err__() check_err__()
if root != rank: if root != rank:
dptr = (ctypes.c_char * length.value)() dptr = (ctypes.c_char * length.value)()
# run second # run second
rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
length.value, root) length.value, root)
check_err__() check_err__()
return pickle.loads(dptr.value) data = pickle.loads(dptr.raw)
del dptr
else:
rbtlib.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
length.value, root)
check_err__()
del s
return data
def allreduce(data, op, prepare_fun = None): # enumeration of dtypes
# Integer type code for every supported numpy dtype; the code of a dtype is
# simply its position in the tuple below.
DTYPE_ENUM__ = {
    np.dtype(type_name): type_code
    for type_code, type_name in enumerate((
        'int8', 'uint8',
        'int32', 'uint32',
        'int64', 'uint64',
        'float32', 'float64',
    ))
}
def allreduce(data, op):
""" """
perform allreduce, return the result, this function is not thread-safe perform allreduce, return the result, this function is not thread-safe
Arguments: Arguments:
data: numpy ndarray data: numpy ndarray
input data input data
op: reduction operators, can be MIN, MAX, SUM, BITOR op: int
prepare_fun: lambda : reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: lambda
Lazy preprocessing function, if it is not None, prepare_fun() Lazy preprocessing function, if it is not None, prepare_fun()
will be called by the function before performing allreduce, to intialize the data will be called by the function before performing allreduce, to intialize the data
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
@ -157,25 +176,101 @@ def allreduce(data, op, prepare_fun = None):
buf = data.ravel() buf = data.ravel()
if buf.base is data.base: if buf.base is data.base:
buf = buf.copy() buf = buf.copy()
if buf.dtype is np.dtype('int8'): if buf.dtype not in DTYPE_ENUM__:
dtype = 0
elif buf.dtype is np.dtype('uint8'):
dtype = 1
elif buf.dtype is np.dtype('int32'):
dtype = 2
elif buf.dtype is np.dtype('uint32'):
dtype = 3
elif buf.dtype is np.dtype('int64'):
dtype = 4
elif buf.dtype is np.dtype('uint64'):
dtype = 5
elif buf.dtype is np.dtype('float32'):
dtype = 6
elif buf.dtype is np.dtype('float64'):
dtype = 7
else:
raise Exception('data type %s not supported' % str(buf.dtype)) raise Exception('data type %s not supported' % str(buf.dtype))
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, dtype, op, None, None); buf.size, DTYPE_ENUM__[dtype],
op, None, None);
check_err__() check_err__()
return buf return buf
def load_model__(ptr, length):
    """
    Internal function used by the module,
    unpickle a model from a buffer specified by ptr, length
    Arguments:
        ptr: ctypes.POINTER(ctypes.c_char)
            pointer to the memory region of buffer
        length: int
            the length of buffer, in bytes
    Returns:
        the object obtained by unpickling the buffer content
    """
    # View the foreign memory as a fixed-size char array, then take .raw so
    # that all `length` bytes are copied, including embedded NUL bytes.
    data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents))
    # NOTE: pickle.loads executes whatever the buffer encodes; the buffer
    # must come from a trusted source.
    return pickle.loads(data.raw)
def load_checkpoint(with_local = False):
    """
    Load the latest check point.
    Arguments:
        with_local: boolean [default = False]
            whether the checkpoint contains a local model
    Returns:
        (version, global_model, local_model) if with_local is set,
        (version, global_model) otherwise.
        A returned version of 0 means no model has been CheckPointed yet;
        in that case every model in the tuple is None.
    """
    # output slots filled in by the native call
    global_ptr = ctypes.POINTER(ctypes.c_char)()
    global_size = ctypes.c_ulong()
    if not with_local:
        version = rbtlib.RabitLoadCheckPoint(
            ctypes.byref(global_ptr),
            ctypes.byref(global_size),
            None, None)
        check_err__()
        if version == 0:
            return (version, None)
        return (version, load_model__(global_ptr, global_size.value))
    local_ptr = ctypes.POINTER(ctypes.c_char)()
    local_size = ctypes.c_ulong()
    version = rbtlib.RabitLoadCheckPoint(
        ctypes.byref(global_ptr),
        ctypes.byref(global_size),
        ctypes.byref(local_ptr),
        ctypes.byref(local_size))
    check_err__()
    if version == 0:
        return (version, None, None)
    global_model = load_model__(global_ptr, global_size.value)
    local_model = load_model__(local_ptr, local_size.value)
    return (version, global_model, local_model)
def checkpoint(global_model, local_model = None):
    """
    checkpoint the model, meaning we finished a stage of execution
    every time we call check point, there is a version number which will increase by one
    Arguments:
        global_model: anytype that can be pickled
            globally shared model/state when calling this function,
            the caller needs to guarantee that global_model is the same in all nodes
        local_model: anytype that can be pickled
            local model, that is specific to current node/rank.
            This can be None when no local state is needed.
            local_model requires explicit replication of the model for fault-tolerance,
            which will bring replication cost in checkpoint function,
            while global_model does not need explicit replication.
            It is recommended to use global_model if possible
    """
    # same pickle protocol as broadcast(); pickle.loads detects it on load
    sg = pickle.dumps(global_model, protocol = pickle.HIGHEST_PROTOCOL)
    if local_model is None:
        sl, local_len = None, 0
    else:
        sl = pickle.dumps(local_model, protocol = pickle.HIGHEST_PROTOCOL)
        local_len = len(sl)
    # ctypes marshals a bare Python int as C int; the native signature takes
    # rbt_ulong lengths, so wrap them explicitly.
    rbtlib.RabitCheckPoint(sg, ctypes.c_ulong(len(sg)),
                           sl, ctypes.c_ulong(local_len))
    check_err__()
def version_number():
    """
    Return the version number of the currently stored model,
    i.e. how many calls to CheckPoint have been made so far.
    """
    current_version = rbtlib.RabitVersionNumber()
    check_err__()
    return current_version

View File

@ -28,7 +28,7 @@ struct FHelper<op::BitOR, DType> {
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
utils::Error("DataType does not support bitwise or operation"); utils::Error("DataType does not support bitwise or operation");
} }
}; };
template<typename OP> template<typename OP>
inline void Allreduce_(void *sendrecvbuf_, inline void Allreduce_(void *sendrecvbuf_,
@ -116,6 +116,43 @@ inline void Allreduce(void *sendrecvbuf,
default: utils::Error("unknown enum_op"); default: utils::Error("unknown enum_op");
} }
} }
// temporary storage for the serialized global and local models;
// RabitLoadCheckPoint fills these and returns pointers obtained from them,
// so the returned data presumably stays valid only until the next load
std::string global_buffer, local_buffer;
// wrapper for serialization
// load-only adapter: reads a length-prefixed byte string into *p_str
struct ReadWrapper : public ISerializable {
  // destination string; this wrapper never frees it
  std::string *p_str;
  explicit ReadWrapper(std::string *p_str)
      : p_str(p_str) {}
  // read a uint64_t length header followed by that many payload bytes
  virtual void Load(IStream &fi) {
    uint64_t nbytes;
    utils::Assert(fi.Read(&nbytes, sizeof(nbytes)) != 0,
                  "Read pickle string");
    p_str->resize(nbytes);
    // nothing more to read for an empty payload
    if (nbytes == 0) return;
    utils::Assert(fi.Read(&(*p_str)[0], sizeof(char) * nbytes) != 0,
                  "Read pickle string");
  }
  // saving is not supported by this wrapper
  virtual void Save(IStream &fo) const {
    utils::Error("not implemented");
  }
};
// save-only adapter: writes a raw byte buffer as a length-prefixed string
struct WriteWrapper : public ISerializable {
  // start of the buffer to write; not owned by this wrapper
  const char *data;
  // number of bytes in the buffer
  size_t length;
  explicit WriteWrapper(const char *data,
                        size_t length)
      : data(data), length(length) {
  }
  // loading is not supported by this wrapper
  virtual void Load(IStream &fi) {
    utils::Error("not implemented");
  }
  // write a uint64_t length header followed by the payload bytes
  virtual void Save(IStream &fo) const {
    // the header must carry the full length: a narrower cast would make the
    // header disagree with the number of payload bytes written below
    uint64_t sz = static_cast<uint64_t>(length);
    fo.Write(&sz, sizeof(sz));
    fo.Write(data, length * sizeof(char));
  }
};
} // namespace wrapper } // namespace wrapper
} // namespace rabit } // namespace rabit
extern "C" { extern "C" {
@ -161,4 +198,42 @@ extern "C" {
static_cast<rabit::engine::mpi::OpType>(enum_op), static_cast<rabit::engine::mpi::OpType>(enum_op),
prepare_fun, prepare_arg); prepare_fun, prepare_arg);
} }
// Load the latest checkpoint into the wrapper's global_buffer/local_buffer
// and report their start pointers and lengths through the out parameters.
// The local model is loaded only when out_local_model is not NULL.
int RabitLoadCheckPoint(char **out_global_model,
                        rbt_ulong *out_global_len,
                        char **out_local_model,
                        rbt_ulong *out_local_len) {
  using rabit::BeginPtr;
  using namespace rabit::wrapper;
  ReadWrapper sg(&global_buffer);
  ReadWrapper sl(&local_buffer);
  const bool want_local = (out_local_model != NULL);
  int version;
  if (want_local) {
    version = rabit::LoadCheckPoint(&sg, &sl);
  } else {
    version = rabit::LoadCheckPoint(&sg, NULL);
  }
  // the global model is reported in both modes
  *out_global_model = BeginPtr(global_buffer);
  *out_global_len = static_cast<rbt_ulong>(global_buffer.length());
  if (want_local) {
    *out_local_model = BeginPtr(local_buffer);
    *out_local_len = static_cast<rbt_ulong>(local_buffer.length());
  }
  return version;
}
// Checkpoint the serialized global model and, when local_model is not NULL,
// the serialized local model as well.
void RabitCheckPoint(const char *global_model,
                     rbt_ulong global_len,
                     const char *local_model,
                     rbt_ulong local_len) {
  using namespace rabit::wrapper;
  WriteWrapper sg(global_model, global_len);
  WriteWrapper sl(local_model, local_len);
  if (local_model != NULL) {
    rabit::CheckPoint(&sg, &sl);
    return;
  }
  // no local state: checkpoint the global model only
  rabit::CheckPoint(&sg, NULL);
}
// C entry point that forwards to rabit::VersionNumber()
int RabitVersionNumber(void) {
  const int version = rabit::VersionNumber();
  return version;
}
} }

View File

@ -81,6 +81,44 @@ extern "C" {
int enum_op, int enum_op,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg); void *prepare_arg);
/*!
 * \brief load the latest check point
 * \param out_global_model receives a pointer to the serialized global_model
 * \param out_global_len receives the length of the serialized global model
 * \param out_local_model receives a pointer to the serialized local_model, can be NULL
 * \param out_local_len receives the length of the serialized local model, can be NULL
 *
 * \return the version number of the check point loaded;
 *    if the returned version == 0, no model has been CheckPointed and
 *    nothing will be touched
 *
 * NOTE(review): the wrapper implementation appears to assign the output
 *    pointers and lengths even when 0 is returned; confirm what
 *    "nothing will be touched" is meant to cover
 */
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
rbt_ulong *out_global_len,
char **out_local_model,
rbt_ulong *out_local_len);
/*!
 * \brief checkpoint the model, meaning we finished a stage of execution;
 *  every time we call check point, the version number increases by one
 *
 * \param global_model holds the content of the serialized global_model
 * \param global_len the content length of the serialized global model
 * \param local_model holds the content of the serialized local_model, can be NULL
 * \param local_len the content length of the serialized local model, can be NULL
 *
 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
 *       bring replication cost in the CheckPoint function. global_model does not need explicit
 *       replication, so only CheckPoint with global_model if possible
 */
RABIT_DLL void RabitCheckPoint(const char *global_model,
rbt_ulong global_len,
const char *local_model,
rbt_ulong local_len);
/*!
 * \return the version number of the currently stored model,
 *         i.e. how many calls to CheckPoint have been made so far
 */
RABIT_DLL int RabitVersionNumber(void);
#ifdef __cplusplus #ifdef __cplusplus
} // C } // C
#endif #endif