Update datatable usage (#4123)
This commit is contained in:
parent
754fe8142b
commit
ff2d4c99fa
@ -49,7 +49,11 @@ except ImportError:
|
|||||||
|
|
||||||
# dt
|
# dt
|
||||||
try:
|
try:
|
||||||
from datatable import DataTable
|
import datatable
|
||||||
|
if hasattr(datatable, "Frame"):
|
||||||
|
DataTable = datatable.Frame
|
||||||
|
else:
|
||||||
|
DataTable = datatable.DataTable
|
||||||
DT_INSTALLED = True
|
DT_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
|
|||||||
@ -287,10 +287,10 @@ def _maybe_dt_data(data, feature_names, feature_types):
|
|||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||||
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
|
bad_fields = [data.names[i]
|
||||||
bad_fields = [data.names[i] for i, type_name in
|
for i, type_name in enumerate(data_types_names)
|
||||||
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]
|
if type_name not in DT_TYPE_MAPPER]
|
||||||
|
if bad_fields:
|
||||||
msg = """DataFrame.types for data must be int, float or bool.
|
msg = """DataFrame.types for data must be int, float or bool.
|
||||||
Did not expect the data types in fields """
|
Did not expect the data types in fields """
|
||||||
raise ValueError(msg + ', '.join(bad_fields))
|
raise ValueError(msg + ', '.join(bad_fields))
|
||||||
@ -317,7 +317,7 @@ def _maybe_dt_array(array):
|
|||||||
|
|
||||||
# below requires new dt version
|
# below requires new dt version
|
||||||
# extract first column
|
# extract first column
|
||||||
array = array.tonumpy()[:, 0].astype('float')
|
array = array.to_numpy()[:, 0].astype('float')
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|
||||||
@ -340,7 +340,7 @@ class DMatrix(object):
|
|||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
|
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
|
||||||
Data source of DMatrix.
|
Data source of DMatrix.
|
||||||
When data is string type, it represents the path libsvm format txt file,
|
When data is string type, it represents the path libsvm format txt file,
|
||||||
or binary file that xgboost can read from.
|
or binary file that xgboost can read from.
|
||||||
@ -497,16 +497,20 @@ class DMatrix(object):
|
|||||||
|
|
||||||
def _init_from_dt(self, data, nthread):
|
def _init_from_dt(self, data, nthread):
|
||||||
"""
|
"""
|
||||||
Initialize data from a DataTable
|
Initialize data from a datatable Frame.
|
||||||
"""
|
"""
|
||||||
cols = []
|
|
||||||
ptrs = (ctypes.c_void_p * data.ncols)()
|
ptrs = (ctypes.c_void_p * data.ncols)()
|
||||||
for icol in range(data.ncols):
|
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
||||||
col = data.internal.column(icol)
|
# datatable>0.8.0
|
||||||
cols.append(col)
|
for icol in range(data.ncols):
|
||||||
# int64_t (void*)
|
col = data.internal.column(icol)
|
||||||
ptr = col.data_pointer
|
ptr = col.data_pointer
|
||||||
ptrs[icol] = ctypes.c_void_p(ptr)
|
ptrs[icol] = ctypes.c_void_p(ptr)
|
||||||
|
else:
|
||||||
|
# datatable<=0.8.0
|
||||||
|
from datatable.internal import frame_column_data_r
|
||||||
|
for icol in range(data.ncols):
|
||||||
|
ptrs[icol] = frame_column_data_r(data, icol)
|
||||||
|
|
||||||
# always return stypes for dt ingestion
|
# always return stypes for dt ingestion
|
||||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user