From 877fc42e40b472112cc6c43311e5530ed0040264 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 13 Jan 2015 12:51:55 -0800 Subject: [PATCH] add data --- wrapper/rabit.py | 153 +++++++++++++++++++++++++++++++-------- wrapper/rabit_wrapper.cc | 77 +++++++++++++++++++- wrapper/rabit_wrapper.h | 38 ++++++++++ 3 files changed, 238 insertions(+), 30 deletions(-) diff --git a/wrapper/rabit.py b/wrapper/rabit.py index 5458c339c..2fd0fd745 100644 --- a/wrapper/rabit.py +++ b/wrapper/rabit.py @@ -18,6 +18,7 @@ else: rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH) rbtlib.RabitGetRank.restype = ctypes.c_int rbtlib.RabitGetWorldSize.restype = ctypes.c_int +rbtlib.RabitVersionNumber.restype = ctypes.c_int # reduction operators MAX = 0 @@ -111,7 +112,7 @@ def broadcast(data, root): Arguments: data: anytype that can be pickled - input data, if current rank does not equal root, this can be None + input data, if current rank does not equal root, this can be None root: int rank of the node to broadcast data from Returns: @@ -123,8 +124,6 @@ def broadcast(data, root): assert data is not None, 'need to pass in data when broadcasting' s = pickle.dumps(data, protocol = pickle.HIGHEST_PROTOCOL) length.value = len(s) - dptr = (ctypes.c_char * length.value)() - dptr[:] = s # run first broadcast rbtlib.RabitBroadcast(ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), @@ -132,20 +131,40 @@ def broadcast(data, root): check_err__() if root != rank: dptr = (ctypes.c_char * length.value)() - # run second - rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), - length.value, root) - check_err__() - return pickle.loads(dptr.value) + # run second + rbtlib.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p), + length.value, root) + check_err__() + data = pickle.loads(dptr.raw) + del dptr + else: + rbtlib.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), + length.value, root) + check_err__() + del s + return data -def allreduce(data, op, prepare_fun = None): +# enumeration of dtypes +DTYPE_ENUM__ = { + np.dtype('int8') : 0, + np.dtype('uint8') : 1, + np.dtype('int32') : 2, + np.dtype('uint32') : 3, + np.dtype('int64') : 4, + np.dtype('uint64') : 5, + np.dtype('float32') : 6, + np.dtype('float64') : 7 +} + +def allreduce(data, op): """ perform allreduce, return the result, this function is not thread-safe Arguments: data: numpy ndarray input data - op: reduction operators, can be MIN, MAX, SUM, BITOR - prepare_fun: lambda : + op: int + reduction operators, can be MIN, MAX, SUM, BITOR + prepare_fun: lambda Lazy preprocessing function, if it is not None, prepare_fun() will be called by the function before performing allreduce, to intialize the data If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called @@ -157,25 +176,101 @@ def allreduce(data, op, prepare_fun = None): buf = data.ravel() if buf.base is data.base: buf = buf.copy() - if buf.dtype is np.dtype('int8'): - dtype = 0 - elif buf.dtype is np.dtype('uint8'): - dtype = 1 - elif buf.dtype is np.dtype('int32'): - dtype = 2 - elif buf.dtype is np.dtype('uint32'): - dtype = 3 - elif buf.dtype is np.dtype('int64'): - dtype = 4 - elif buf.dtype is np.dtype('uint64'): - dtype = 5 - elif buf.dtype is np.dtype('float32'): - dtype = 6 - elif buf.dtype is np.dtype('float64'): - dtype = 7 - else: + if buf.dtype not in DTYPE_ENUM__: raise Exception('data type %s not supported' % str(buf.dtype)) rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), - buf.size, dtype, op, None, None); + buf.size, DTYPE_ENUM__[dtype], + op, None, None); check_err__() return buf + + +def load_model__(ptr, length): + """ + Internal function used by the module, + unpickle a model from a buffer specified by ptr, length + Arguments: + ptr: ctypes.POINTER(ctypes._char) + pointer to the memory region of buffer + length: int + the length of buffer + """ + data = (ctypes.c_char * length).from_address(addressof(ptr.contents)) + return pickle.loads(data.raw) + +def load_checkpoint(with_local = False): + """ + load latest check point + Arguments: + with_local: boolean [default = False] + whether the checkpoint contains local model + Returns: + if with_local: return (version, gobal_model, local_model) + else return (version, gobal_model) + if returned version == 0, this means no model has been CheckPointed + and global_model, local_model returned will be None + """ + gp = ctypes.POINTER(ctypes.c_char)() + global_len = ctypes.c_ulong() + if with_local: + lp = ctypes.POINTER(ctypes.c_char)() + local_len = ctypes.c_ulong() + version = rbtlib.RabitLoadCheckPoint( + ctypes.byref(gp), + ctypes.byref(global_len), + ctypes.byref(lp), + ctypes.byref(local_len)) + check_err__() + if version == 0: + return (version, None, None) + return (version, + load_model__(gp, global_len.value), + load_model__(lp, local_len.value)) + else: + version = rbtlib.RabitLoadCheckPoint( + ctypes.byref(gp), + ctypes.byref(global_len), + None, None) + check_err__() + if version == 0: + return (version, None) + return (version, + load_model__(gp, global_len.value)) + +def checkpoint(global_model, local_model = None): + """ + checkpoint the model, meaning we finished a stage of execution + every time we call check point, there is a version number which will increase by one + + Arguments: + global_model: anytype that can be pickled + globally shared model/state when calling this function, + the caller need to gauranttees that global_model is the same in all nodes + local_model: anytype that can be pickled + local model, that is specific to current node/rank. + This can be None when no local state is needed. + local_model requires explicit replication of the model for fault-tolerance, + which will bring replication cost in checkpoint function, + while global_model do not need explicit replication. + It is recommended to use global_model if possible + """ + sg = pickle.dumps(global_model) + if local_model is None: + rbtlib.RabitCheckPoint(sg, len(sg), None, 0) + check_err__() + del sg; + else: + sl = pickle.dumps(local_model) + rbtlib.RabitCheckPoint(sg, len(sg), sl, len(sl)) + check_err__() + del sl; del sg; + +def version_number(): + """ + Returns version number of current stored model, + which means how many calls to CheckPoint we made so far + """ + ret = rbtlib.RabitVersionNumber() + check_err__() + return ret + diff --git a/wrapper/rabit_wrapper.cc b/wrapper/rabit_wrapper.cc index 60c53f93a..9ff25b09f 100644 --- a/wrapper/rabit_wrapper.cc +++ b/wrapper/rabit_wrapper.cc @@ -28,7 +28,7 @@ struct FHelper { void (*prepare_fun)(void *arg), void *prepare_arg) { utils::Error("DataType does not support bitwise or operation"); - } + } }; template inline void Allreduce_(void *sendrecvbuf_, @@ -116,6 +116,43 @@ inline void Allreduce(void *sendrecvbuf, default: utils::Error("unknown enum_op"); } } +// temporal memory for global and local model +std::string global_buffer, local_buffer; +// wrapper for serialization +struct ReadWrapper : public ISerializable { + std::string *p_str; + explicit ReadWrapper(std::string *p_str) + : p_str(p_str) {} + virtual void Load(IStream &fi) { + uint64_t sz; + utils::Assert(fi.Read(&sz, sizeof(sz)) != 0, + "Read pickle string"); + p_str->resize(sz); + if (sz != 0) { + utils::Assert(fi.Read(&(*p_str)[0], sizeof(char) * sz) != 0, + "Read pickle string"); + } + } + virtual void Save(IStream &fo) const { + utils::Error("not implemented"); + } +}; +struct WriteWrapper : public ISerializable { + const char *data; + size_t length; + explicit WriteWrapper(const char *data, + size_t length) + : data(data), length(length) { + } + virtual void Load(IStream &fi) { + utils::Error("not implemented"); + } + virtual void Save(IStream &fo) const { + uint64_t sz = static_cast(length); + fo.Write(&sz, sizeof(sz)); + fo.Write(data, length * sizeof(char)); + } +}; } // namespace wrapper } // namespace rabit extern "C" { @@ -161,4 +198,42 @@ extern "C" { static_cast(enum_op), prepare_fun, prepare_arg); } + int RabitLoadCheckPoint(char **out_global_model, + rbt_ulong *out_global_len, + char **out_local_model, + rbt_ulong *out_local_len) { + using rabit::BeginPtr; + using namespace rabit::wrapper; + ReadWrapper sg(&global_buffer); + ReadWrapper sl(&local_buffer); + int version; + if (out_local_model == NULL) { + version = rabit::LoadCheckPoint(&sg, NULL); + *out_global_model = BeginPtr(global_buffer); + *out_global_len = static_cast(global_buffer.length()); + } else { + version = rabit::LoadCheckPoint(&sg, &sl); + *out_global_model = BeginPtr(global_buffer); + *out_global_len = static_cast(global_buffer.length()); + *out_local_model = BeginPtr(local_buffer); + *out_local_len = static_cast(local_buffer.length()); + } + return version; + } + void RabitCheckPoint(const char *global_model, + rbt_ulong global_len, + const char *local_model, + rbt_ulong local_len) { + using namespace rabit::wrapper; + WriteWrapper sg(global_model, global_len); + WriteWrapper sl(local_model, local_len); + if (local_model == NULL) { + rabit::CheckPoint(&sg, NULL); + } else { + rabit::CheckPoint(&sg, &sl); + } + } + int RabitVersionNumber(void) { + return rabit::VersionNumber(); + } } diff --git a/wrapper/rabit_wrapper.h b/wrapper/rabit_wrapper.h index 823f199b9..39caa70b4 100644 --- a/wrapper/rabit_wrapper.h +++ b/wrapper/rabit_wrapper.h @@ -81,6 +81,44 @@ extern "C" { int enum_op, void (*prepare_fun)(void *arg), void *prepare_arg); + + /*! + * \brief load latest check point + * \param out_global_model hold output of serialized global_model + * \param out_global_len the output length of serialized global model + * \param out_local_model hold output of serialized local_model, can be NULL + * \param out_local_len the output length of serialized local model, can be NULL + * + * \return the version number of check point loaded + * if returned version == 0, this means no model has been CheckPointed + * nothing will be touched + */ + RABIT_DLL int RabitLoadCheckPoint(char **out_global_model, + rbt_ulong *out_global_len, + char **out_local_model, + rbt_ulong *out_local_len); + /*! + * \brief checkpoint the model, meaning we finished a stage of execution + * every time we call check point, there is a version number which will increase by one + * + * \param global_model hold content of serialized global_model + * \param global_len the content length of serialized global model + * \param local_model hold content of serialized local_model, can be NULL + * \param local_len the content length of serialized local model, can be NULL + * + * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will + * bring replication cost in CheckPoint function. global_model do not need explicit replication. + * So only CheckPoint with global_model if possible + */ + RABIT_DLL void RabitCheckPoint(const char *global_model, + rbt_ulong global_len, + const char *local_model, + rbt_ulong local_len); + /*! + * \return version number of current stored model, + * which means how many calls to CheckPoint we made so far + */ + RABIT_DLL int RabitVersionNumber(void); #ifdef __cplusplus } // C #endif