Don't store DMatrix handle until it's initialized. (#4317)

* Use a temporary variable to store the handle.
* Decode the C++ error message before raising XGBoostError.
* Add a short note to the save_binary docstring on how to load a saved binary.
This commit is contained in:
Jiaming Yuan 2019-04-01 18:29:28 +08:00 committed by GitHub
parent 2f7087eba1
commit 82dca3c108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -175,7 +175,7 @@ def _check_call(ret):
return value from API calls return value from API calls
""" """
if ret != 0: if ret != 0:
raise XGBoostError(_LIB.XGBGetLastError()) raise XGBoostError(py_str(_LIB.XGBGetLastError()))
def ctypes2numpy(cptr, length, dtype): def ctypes2numpy(cptr, length, dtype):
@ -395,10 +395,11 @@ class DMatrix(object):
DeprecationWarning) DeprecationWarning)
if isinstance(data, STRING_TYPES): if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
ctypes.c_int(silent), ctypes.c_int(silent),
ctypes.byref(self.handle))) ctypes.byref(handle)))
self.handle = handle
elif isinstance(data, scipy.sparse.csr_matrix): elif isinstance(data, scipy.sparse.csr_matrix):
self._init_from_csr(data) self._init_from_csr(data)
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
@ -435,14 +436,15 @@ class DMatrix(object):
""" """
if len(csr.indices) != len(csr.data): if len(csr.indices) != len(csr.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data))) raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
self.handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr), _check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
c_array(ctypes.c_uint, csr.indices), c_array(ctypes.c_uint, csr.indices),
c_array(ctypes.c_float, csr.data), c_array(ctypes.c_float, csr.data),
ctypes.c_size_t(len(csr.indptr)), ctypes.c_size_t(len(csr.indptr)),
ctypes.c_size_t(len(csr.data)), ctypes.c_size_t(len(csr.data)),
ctypes.c_size_t(csr.shape[1]), ctypes.c_size_t(csr.shape[1]),
ctypes.byref(self.handle))) ctypes.byref(handle)))
self.handle = handle
def _init_from_csc(self, csc): def _init_from_csc(self, csc):
""" """
@ -450,14 +452,15 @@ class DMatrix(object):
""" """
if len(csc.indices) != len(csc.data): if len(csc.indices) != len(csc.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data))) raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
self.handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr), _check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
c_array(ctypes.c_uint, csc.indices), c_array(ctypes.c_uint, csc.indices),
c_array(ctypes.c_float, csc.data), c_array(ctypes.c_float, csc.data),
ctypes.c_size_t(len(csc.indptr)), ctypes.c_size_t(len(csc.indptr)),
ctypes.c_size_t(len(csc.data)), ctypes.c_size_t(len(csc.data)),
ctypes.c_size_t(csc.shape[0]), ctypes.c_size_t(csc.shape[0]),
ctypes.byref(self.handle))) ctypes.byref(handle)))
self.handle = handle
def _init_from_npy2d(self, mat, missing, nthread): def _init_from_npy2d(self, mat, missing, nthread):
""" """
@ -477,7 +480,7 @@ class DMatrix(object):
# we try to avoid data copies if possible (reshape returns a view when possible # we try to avoid data copies if possible (reshape returns a view when possible
# and we explicitly tell np.array to try and avoid copying) # and we explicitly tell np.array to try and avoid copying)
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32) data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
self.handle = ctypes.c_void_p() handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan missing = missing if missing is not None else np.nan
if nthread is None: if nthread is None:
_check_call(_LIB.XGDMatrixCreateFromMat( _check_call(_LIB.XGDMatrixCreateFromMat(
@ -485,15 +488,16 @@ class DMatrix(object):
c_bst_ulong(mat.shape[0]), c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]), c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing), ctypes.c_float(missing),
ctypes.byref(self.handle))) ctypes.byref(handle)))
else: else:
_check_call(_LIB.XGDMatrixCreateFromMat_omp( _check_call(_LIB.XGDMatrixCreateFromMat_omp(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]), c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]), c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing), ctypes.c_float(missing),
ctypes.byref(self.handle), ctypes.byref(handle),
nthread)) nthread))
self.handle = handle
def _init_from_dt(self, data, nthread): def _init_from_dt(self, data, nthread):
""" """
@ -517,14 +521,14 @@ class DMatrix(object):
for icol in range(data.ncols): for icol in range(data.ncols):
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8')) feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
self.handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromDT( _check_call(_LIB.XGDMatrixCreateFromDT(
ptrs, feature_type_strings, ptrs, feature_type_strings,
c_bst_ulong(data.shape[0]), c_bst_ulong(data.shape[0]),
c_bst_ulong(data.shape[1]), c_bst_ulong(data.shape[1]),
ctypes.byref(self.handle), ctypes.byref(handle),
nthread)) nthread))
self.handle = handle
def __del__(self): def __del__(self):
if hasattr(self, "handle") and self.handle is not None: if hasattr(self, "handle") and self.handle is not None:
@ -646,7 +650,8 @@ class DMatrix(object):
c_bst_ulong(len(data)))) c_bst_ulong(len(data))))
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
"""Save DMatrix to an XGBoost buffer. """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
by providing the path to :py:func:`xgboost.DMatrix` as input.
Parameters Parameters
---------- ----------