Take datatable as row major input. (#8472)

* Take datatable as row major input.

Try to avoid a transform for dense tables.
This commit is contained in:
Jiaming Yuan
2022-11-24 09:20:13 +08:00
committed by GitHub
parent 284dcf8d22
commit e07245f110
4 changed files with 84 additions and 92 deletions

View File

@@ -44,8 +44,9 @@ _matrix_meta = {"base_margin", "label"}
def _warn_unused_missing(data: DataType, missing: Optional[FloatCompatible]) -> None:
if (missing is not None) and (not np.isnan(missing)):
warnings.warn(
'`missing` is not used for current input data type:' +
str(type(data)), UserWarning)
"`missing` is not used for current input data type:" + str(type(data)),
UserWarning,
)
def _check_complex(data: DataType) -> None:
@@ -459,12 +460,9 @@ def _from_pandas_series(
def _is_dt_df(data: DataType) -> bool:
return lazy_isinstance(data, 'datatable', 'Frame') or \
lazy_isinstance(data, 'datatable', 'DataTable')
_dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
_dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
return lazy_isinstance(data, "datatable", "Frame") or lazy_isinstance(
data, "datatable", "DataTable"
)
def _transform_dt_df(
@@ -475,8 +473,10 @@ def _transform_dt_df(
meta_type: Optional[NumpyDType] = None,
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
"""Validate feature names and types if data table"""
_dt_type_mapper = {"bool": "bool", "int": "int", "real": "float"}
_dt_type_mapper2 = {"bool": "i", "int": "int", "real": "float"}
if meta and data.shape[1] > 1:
raise ValueError('DataTable for meta info cannot have multiple columns')
raise ValueError("DataTable for meta info cannot have multiple columns")
if meta:
meta_type = "float" if meta_type is None else meta_type
# below requires new dt version
@@ -485,23 +485,23 @@ def _transform_dt_df(
return data, None, None
data_types_names = tuple(lt.name for lt in data.ltypes)
bad_fields = [data.names[i]
for i, type_name in enumerate(data_types_names)
if type_name not in _dt_type_mapper]
bad_fields = [
data.names[i]
for i, type_name in enumerate(data_types_names)
if type_name not in _dt_type_mapper
]
if bad_fields:
msg = """DataFrame.types for data must be int, float or bool.
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
raise ValueError(msg + ", ".join(bad_fields))
if feature_names is None and meta is None:
feature_names = data.names
# always return stypes for dt ingestion
if feature_types is not None:
raise ValueError(
'DataTable has own feature types, cannot pass them in.')
feature_types = np.vectorize(_dt_type_mapper2.get)(
data_types_names).tolist()
raise ValueError("DataTable has own feature types, cannot pass them in.")
feature_types = np.vectorize(_dt_type_mapper2.get)(data_types_names).tolist()
return data, feature_names, feature_types
@@ -517,7 +517,8 @@ def _from_dt_df(
if enable_categorical:
raise ValueError("categorical data in datatable is not supported yet.")
data, feature_names, feature_types = _transform_dt_df(
data, feature_names, feature_types, None, None)
data, feature_names, feature_types, None, None
)
ptrs = (ctypes.c_void_p * data.ncols)()
if hasattr(data, "internal") and hasattr(data.internal, "column"):
@@ -531,6 +532,7 @@ def _from_dt_df(
from datatable.internal import (
frame_column_data_r, # pylint: disable=no-name-in-module
)
for icol in range(data.ncols):
ptrs[icol] = frame_column_data_r(data, icol)
@@ -538,16 +540,21 @@ def _from_dt_df(
feature_type_strings = (ctypes.c_char_p * data.ncols)()
for icol in range(data.ncols):
feature_type_strings[icol] = ctypes.c_char_p(
data.stypes[icol].name.encode('utf-8'))
data.stypes[icol].name.encode("utf-8")
)
_warn_unused_missing(data, missing)
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromDT(
ptrs, feature_type_strings,
c_bst_ulong(data.shape[0]),
c_bst_ulong(data.shape[1]),
ctypes.byref(handle),
ctypes.c_int(nthread)))
_check_call(
_LIB.XGDMatrixCreateFromDT(
ptrs,
feature_type_strings,
c_bst_ulong(data.shape[0]),
c_bst_ulong(data.shape[1]),
ctypes.byref(handle),
ctypes.c_int(nthread),
)
)
return handle, feature_names, feature_types