Move feature names and types of DMatrix from Python to C++. (#5858)
* Add thread local return entry for DMatrix. * Save feature name and feature type in binary file. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -305,12 +305,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
DMatrix is an internal data structure that is used by XGBoost
|
||||
which is optimized for both memory efficiency and training speed.
|
||||
You can construct DMatrix from numpy.arrays
|
||||
You can construct DMatrix from multiple different sources of data.
|
||||
"""
|
||||
|
||||
_feature_names = None # for previous version's pickle
|
||||
_feature_types = None
|
||||
|
||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
||||
missing=None,
|
||||
silent=False,
|
||||
@@ -362,11 +359,6 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
# force into void_p, mac need to pass things in as void_p
|
||||
if data is None:
|
||||
self.handle = None
|
||||
|
||||
if feature_names is not None:
|
||||
self._feature_names = feature_names
|
||||
if feature_types is not None:
|
||||
self._feature_types = feature_types
|
||||
return
|
||||
|
||||
handler = self.get_data_handler(data)
|
||||
@@ -666,14 +658,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
res : DMatrix
|
||||
A new DMatrix containing only selected indices.
|
||||
"""
|
||||
res = DMatrix(None, feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
res = DMatrix(None)
|
||||
res.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrixEx(self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle),
|
||||
ctypes.c_int(1 if allow_groups else 0)))
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrixEx(
|
||||
self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle),
|
||||
ctypes.c_int(1 if allow_groups else 0)))
|
||||
res.feature_names = self.feature_names
|
||||
res.feature_types = self.feature_types
|
||||
return res
|
||||
|
||||
@property
|
||||
@@ -684,20 +678,17 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
-------
|
||||
feature_names : list or None
|
||||
"""
|
||||
if self._feature_names is None:
|
||||
self._feature_names = ['f{0}'.format(i)
|
||||
for i in range(self.num_col())]
|
||||
return self._feature_names
|
||||
|
||||
@property
|
||||
def feature_types(self):
|
||||
"""Get feature types (column types).
|
||||
|
||||
Returns
|
||||
-------
|
||||
feature_types : list or None
|
||||
"""
|
||||
return self._feature_types
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
|
||||
c_str('feature_name'),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
feature_names = from_cstr_to_pystr(sarr, length)
|
||||
if not feature_names:
|
||||
feature_names = ['f{0}'.format(i)
|
||||
for i in range(self.num_col())]
|
||||
return feature_names
|
||||
|
||||
@feature_names.setter
|
||||
def feature_names(self, feature_names):
|
||||
@@ -728,10 +719,41 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
not any(x in f for x in set(('[', ']', '<')))
|
||||
for f in feature_names):
|
||||
raise ValueError('feature_names must be string, and may not contain [, ] or <')
|
||||
c_feature_names = [bytes(f, encoding='utf-8')
|
||||
for f in feature_names]
|
||||
c_feature_names = (ctypes.c_char_p *
|
||||
len(c_feature_names))(*c_feature_names)
|
||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||
self.handle, c_str('feature_name'),
|
||||
c_feature_names,
|
||||
c_bst_ulong(len(feature_names))))
|
||||
else:
|
||||
# reset feature_types also
|
||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||
self.handle,
|
||||
c_str('feature_name'),
|
||||
None,
|
||||
c_bst_ulong(0)))
|
||||
self.feature_types = None
|
||||
self._feature_names = feature_names
|
||||
|
||||
@property
|
||||
def feature_types(self):
|
||||
"""Get feature types (column types).
|
||||
|
||||
Returns
|
||||
-------
|
||||
feature_types : list or None
|
||||
"""
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
|
||||
c_str('feature_type'),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
res = from_cstr_to_pystr(sarr, length)
|
||||
if not res:
|
||||
return None
|
||||
return res
|
||||
|
||||
@feature_types.setter
|
||||
def feature_types(self, feature_types):
|
||||
@@ -746,14 +768,12 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
Labels for features. None will reset existing feature types
|
||||
"""
|
||||
if feature_types is not None:
|
||||
if self._feature_names is None:
|
||||
msg = 'Unable to set feature types before setting names'
|
||||
raise ValueError(msg)
|
||||
|
||||
if not isinstance(feature_types, (list, str)):
|
||||
raise TypeError(
|
||||
'feature_types must be string or list of strings')
|
||||
if isinstance(feature_types, STRING_TYPES):
|
||||
# single string will be applied to all columns
|
||||
feature_types = [feature_types] * self.num_col()
|
||||
|
||||
try:
|
||||
if not isinstance(feature_types, str):
|
||||
feature_types = list(feature_types)
|
||||
@@ -761,16 +781,25 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
feature_types = [feature_types]
|
||||
except TypeError:
|
||||
feature_types = [feature_types]
|
||||
c_feature_types = [bytes(f, encoding='utf-8')
|
||||
for f in feature_types]
|
||||
c_feature_types = (ctypes.c_char_p *
|
||||
len(c_feature_types))(*c_feature_types)
|
||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||
self.handle, c_str('feature_type'),
|
||||
c_feature_types,
|
||||
c_bst_ulong(len(feature_types))))
|
||||
|
||||
if len(feature_types) != self.num_col():
|
||||
msg = 'feature_types must have the same length as data'
|
||||
raise ValueError(msg)
|
||||
|
||||
valid = ('int', 'float', 'i', 'q')
|
||||
if not all(isinstance(f, STRING_TYPES) and f in valid
|
||||
for f in feature_types):
|
||||
raise ValueError('All feature_names must be {int, float, i, q}')
|
||||
self._feature_types = feature_types
|
||||
else:
|
||||
# Reset.
|
||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||
self.handle,
|
||||
c_str('feature_type'),
|
||||
None,
|
||||
c_bst_ulong(0)))
|
||||
|
||||
|
||||
class DeviceQuantileDMatrix(DMatrix):
|
||||
|
||||
@@ -372,7 +372,7 @@ class DTHandler(DataHandler):
|
||||
raise ValueError(
|
||||
'DataTable has own feature types, cannot pass them in.')
|
||||
feature_types = np.vectorize(self.dt_type_mapper2.get)(
|
||||
data_types_names)
|
||||
data_types_names).tolist()
|
||||
|
||||
return data, feature_names, feature_types
|
||||
|
||||
|
||||
Reference in New Issue
Block a user