diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index f65213955..ed0b1c2df 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -97,11 +97,12 @@ def ctypes2numpy(cptr, length, dtype): def ctypes2buffer(cptr, length): if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): raise RuntimeError('expected char pointer') - res = np.zeros(length, dtype='uint8') - if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): + res = bytearray(length) + rptr = (ctypes.c_char * length).from_buffer(res) + if not ctypes.memmove(rptr, cptr, length): raise RuntimeError('memmove failed') return res - + def c_str(string): return ctypes.c_char_p(string.encode('utf-8')) @@ -886,7 +887,7 @@ class XGBModel(XGBModelBase): def __setstate__(self, state): bst = state["_Booster"] if bst is not None: - state["_Booster"] = Booster(model_file=booster) + state["_Booster"] = Booster(model_file=bst) self.__dict__.update(state) def booster(self):