Don't store DMatrix handle until it's initialized. (#4317)
* Use a temporary variable to store the handle. * Decode c++ error message. * Simple note about saved binary.
This commit is contained in:
parent
2f7087eba1
commit
82dca3c108
@ -175,7 +175,7 @@ def _check_call(ret):
|
|||||||
return value from API calls
|
return value from API calls
|
||||||
"""
|
"""
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
raise XGBoostError(_LIB.XGBGetLastError())
|
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
|
||||||
|
|
||||||
|
|
||||||
def ctypes2numpy(cptr, length, dtype):
|
def ctypes2numpy(cptr, length, dtype):
|
||||||
@ -395,10 +395,11 @@ class DMatrix(object):
|
|||||||
DeprecationWarning)
|
DeprecationWarning)
|
||||||
|
|
||||||
if isinstance(data, STRING_TYPES):
|
if isinstance(data, STRING_TYPES):
|
||||||
self.handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
||||||
ctypes.c_int(silent),
|
ctypes.c_int(silent),
|
||||||
ctypes.byref(self.handle)))
|
ctypes.byref(handle)))
|
||||||
|
self.handle = handle
|
||||||
elif isinstance(data, scipy.sparse.csr_matrix):
|
elif isinstance(data, scipy.sparse.csr_matrix):
|
||||||
self._init_from_csr(data)
|
self._init_from_csr(data)
|
||||||
elif isinstance(data, scipy.sparse.csc_matrix):
|
elif isinstance(data, scipy.sparse.csc_matrix):
|
||||||
@ -435,14 +436,15 @@ class DMatrix(object):
|
|||||||
"""
|
"""
|
||||||
if len(csr.indices) != len(csr.data):
|
if len(csr.indices) != len(csr.data):
|
||||||
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
|
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
|
||||||
self.handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
|
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
|
||||||
c_array(ctypes.c_uint, csr.indices),
|
c_array(ctypes.c_uint, csr.indices),
|
||||||
c_array(ctypes.c_float, csr.data),
|
c_array(ctypes.c_float, csr.data),
|
||||||
ctypes.c_size_t(len(csr.indptr)),
|
ctypes.c_size_t(len(csr.indptr)),
|
||||||
ctypes.c_size_t(len(csr.data)),
|
ctypes.c_size_t(len(csr.data)),
|
||||||
ctypes.c_size_t(csr.shape[1]),
|
ctypes.c_size_t(csr.shape[1]),
|
||||||
ctypes.byref(self.handle)))
|
ctypes.byref(handle)))
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
def _init_from_csc(self, csc):
|
def _init_from_csc(self, csc):
|
||||||
"""
|
"""
|
||||||
@ -450,14 +452,15 @@ class DMatrix(object):
|
|||||||
"""
|
"""
|
||||||
if len(csc.indices) != len(csc.data):
|
if len(csc.indices) != len(csc.data):
|
||||||
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
|
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
|
||||||
self.handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
|
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
|
||||||
c_array(ctypes.c_uint, csc.indices),
|
c_array(ctypes.c_uint, csc.indices),
|
||||||
c_array(ctypes.c_float, csc.data),
|
c_array(ctypes.c_float, csc.data),
|
||||||
ctypes.c_size_t(len(csc.indptr)),
|
ctypes.c_size_t(len(csc.indptr)),
|
||||||
ctypes.c_size_t(len(csc.data)),
|
ctypes.c_size_t(len(csc.data)),
|
||||||
ctypes.c_size_t(csc.shape[0]),
|
ctypes.c_size_t(csc.shape[0]),
|
||||||
ctypes.byref(self.handle)))
|
ctypes.byref(handle)))
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
def _init_from_npy2d(self, mat, missing, nthread):
|
def _init_from_npy2d(self, mat, missing, nthread):
|
||||||
"""
|
"""
|
||||||
@ -477,7 +480,7 @@ class DMatrix(object):
|
|||||||
# we try to avoid data copies if possible (reshape returns a view when possible
|
# we try to avoid data copies if possible (reshape returns a view when possible
|
||||||
# and we explicitly tell np.array to try and avoid copying)
|
# and we explicitly tell np.array to try and avoid copying)
|
||||||
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
|
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
|
||||||
self.handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
missing = missing if missing is not None else np.nan
|
missing = missing if missing is not None else np.nan
|
||||||
if nthread is None:
|
if nthread is None:
|
||||||
_check_call(_LIB.XGDMatrixCreateFromMat(
|
_check_call(_LIB.XGDMatrixCreateFromMat(
|
||||||
@ -485,15 +488,16 @@ class DMatrix(object):
|
|||||||
c_bst_ulong(mat.shape[0]),
|
c_bst_ulong(mat.shape[0]),
|
||||||
c_bst_ulong(mat.shape[1]),
|
c_bst_ulong(mat.shape[1]),
|
||||||
ctypes.c_float(missing),
|
ctypes.c_float(missing),
|
||||||
ctypes.byref(self.handle)))
|
ctypes.byref(handle)))
|
||||||
else:
|
else:
|
||||||
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
||||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||||
c_bst_ulong(mat.shape[0]),
|
c_bst_ulong(mat.shape[0]),
|
||||||
c_bst_ulong(mat.shape[1]),
|
c_bst_ulong(mat.shape[1]),
|
||||||
ctypes.c_float(missing),
|
ctypes.c_float(missing),
|
||||||
ctypes.byref(self.handle),
|
ctypes.byref(handle),
|
||||||
nthread))
|
nthread))
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
def _init_from_dt(self, data, nthread):
|
def _init_from_dt(self, data, nthread):
|
||||||
"""
|
"""
|
||||||
@ -517,14 +521,14 @@ class DMatrix(object):
|
|||||||
for icol in range(data.ncols):
|
for icol in range(data.ncols):
|
||||||
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
|
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
|
||||||
|
|
||||||
self.handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
|
|
||||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
_check_call(_LIB.XGDMatrixCreateFromDT(
|
||||||
ptrs, feature_type_strings,
|
ptrs, feature_type_strings,
|
||||||
c_bst_ulong(data.shape[0]),
|
c_bst_ulong(data.shape[0]),
|
||||||
c_bst_ulong(data.shape[1]),
|
c_bst_ulong(data.shape[1]),
|
||||||
ctypes.byref(self.handle),
|
ctypes.byref(handle),
|
||||||
nthread))
|
nthread))
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, "handle") and self.handle is not None:
|
if hasattr(self, "handle") and self.handle is not None:
|
||||||
@ -646,7 +650,8 @@ class DMatrix(object):
|
|||||||
c_bst_ulong(len(data))))
|
c_bst_ulong(len(data))))
|
||||||
|
|
||||||
def save_binary(self, fname, silent=True):
|
def save_binary(self, fname, silent=True):
|
||||||
"""Save DMatrix to an XGBoost buffer.
|
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
||||||
|
by providing the path to :py:func:`xgboost.DMatrix` as input.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user