Don't store DMatrix handle until it's initialized. (#4317)
* Use a temporary variable to store the handle.
* Decode the C++ error message.
* Add a simple note about the saved binary.
This commit is contained in:
parent
2f7087eba1
commit
82dca3c108
@ -175,7 +175,7 @@ def _check_call(ret):
|
||||
return value from API calls
|
||||
"""
|
||||
if ret != 0:
|
||||
raise XGBoostError(_LIB.XGBGetLastError())
|
||||
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
|
||||
|
||||
|
||||
def ctypes2numpy(cptr, length, dtype):
|
||||
@ -395,10 +395,11 @@ class DMatrix(object):
|
||||
DeprecationWarning)
|
||||
|
||||
if isinstance(data, STRING_TYPES):
|
||||
self.handle = ctypes.c_void_p()
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
||||
ctypes.c_int(silent),
|
||||
ctypes.byref(self.handle)))
|
||||
ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
elif isinstance(data, scipy.sparse.csr_matrix):
|
||||
self._init_from_csr(data)
|
||||
elif isinstance(data, scipy.sparse.csc_matrix):
|
||||
@ -435,14 +436,15 @@ class DMatrix(object):
|
||||
"""
|
||||
if len(csr.indices) != len(csr.data):
|
||||
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
|
||||
self.handle = ctypes.c_void_p()
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
|
||||
c_array(ctypes.c_uint, csr.indices),
|
||||
c_array(ctypes.c_float, csr.data),
|
||||
ctypes.c_size_t(len(csr.indptr)),
|
||||
ctypes.c_size_t(len(csr.data)),
|
||||
ctypes.c_size_t(csr.shape[1]),
|
||||
ctypes.byref(self.handle)))
|
||||
ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
def _init_from_csc(self, csc):
|
||||
"""
|
||||
@ -450,14 +452,15 @@ class DMatrix(object):
|
||||
"""
|
||||
if len(csc.indices) != len(csc.data):
|
||||
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
|
||||
self.handle = ctypes.c_void_p()
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
|
||||
c_array(ctypes.c_uint, csc.indices),
|
||||
c_array(ctypes.c_float, csc.data),
|
||||
ctypes.c_size_t(len(csc.indptr)),
|
||||
ctypes.c_size_t(len(csc.data)),
|
||||
ctypes.c_size_t(csc.shape[0]),
|
||||
ctypes.byref(self.handle)))
|
||||
ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
def _init_from_npy2d(self, mat, missing, nthread):
|
||||
"""
|
||||
@ -477,7 +480,7 @@ class DMatrix(object):
|
||||
# we try to avoid data copies if possible (reshape returns a view when possible
|
||||
# and we explicitly tell np.array to try and avoid copying)
|
||||
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
|
||||
self.handle = ctypes.c_void_p()
|
||||
handle = ctypes.c_void_p()
|
||||
missing = missing if missing is not None else np.nan
|
||||
if nthread is None:
|
||||
_check_call(_LIB.XGDMatrixCreateFromMat(
|
||||
@ -485,15 +488,16 @@ class DMatrix(object):
|
||||
c_bst_ulong(mat.shape[0]),
|
||||
c_bst_ulong(mat.shape[1]),
|
||||
ctypes.c_float(missing),
|
||||
ctypes.byref(self.handle)))
|
||||
ctypes.byref(handle)))
|
||||
else:
|
||||
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
c_bst_ulong(mat.shape[0]),
|
||||
c_bst_ulong(mat.shape[1]),
|
||||
ctypes.c_float(missing),
|
||||
ctypes.byref(self.handle),
|
||||
ctypes.byref(handle),
|
||||
nthread))
|
||||
self.handle = handle
|
||||
|
||||
def _init_from_dt(self, data, nthread):
|
||||
"""
|
||||
@ -517,14 +521,14 @@ class DMatrix(object):
|
||||
for icol in range(data.ncols):
|
||||
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
|
||||
|
||||
self.handle = ctypes.c_void_p()
|
||||
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
||||
ptrs, feature_type_strings,
|
||||
c_bst_ulong(data.shape[0]),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
ctypes.byref(self.handle),
|
||||
ctypes.byref(handle),
|
||||
nthread))
|
||||
self.handle = handle
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "handle") and self.handle is not None:
|
||||
@ -646,7 +650,8 @@ class DMatrix(object):
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def save_binary(self, fname, silent=True):
|
||||
"""Save DMatrix to an XGBoost buffer.
|
||||
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
||||
by providing the path to :py:func:`xgboost.DMatrix` as input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user