Update datatable usage (#4123)
This commit is contained in:
parent
754fe8142b
commit
ff2d4c99fa
@ -49,7 +49,11 @@ except ImportError:
|
||||
|
||||
# dt
|
||||
try:
|
||||
from datatable import DataTable
|
||||
import datatable
|
||||
if hasattr(datatable, "Frame"):
|
||||
DataTable = datatable.Frame
|
||||
else:
|
||||
DataTable = datatable.DataTable
|
||||
DT_INSTALLED = True
|
||||
except ImportError:
|
||||
|
||||
|
||||
@ -287,10 +287,10 @@ def _maybe_dt_data(data, feature_names, feature_types):
|
||||
return data, feature_names, feature_types
|
||||
|
||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
|
||||
bad_fields = [data.names[i] for i, type_name in
|
||||
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]
|
||||
|
||||
bad_fields = [data.names[i]
|
||||
for i, type_name in enumerate(data_types_names)
|
||||
if type_name not in DT_TYPE_MAPPER]
|
||||
if bad_fields:
|
||||
msg = """DataFrame.types for data must be int, float or bool.
|
||||
Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
@ -317,7 +317,7 @@ def _maybe_dt_array(array):
|
||||
|
||||
# below requires new dt version
|
||||
# extract first column
|
||||
array = array.tonumpy()[:, 0].astype('float')
|
||||
array = array.to_numpy()[:, 0].astype('float')
|
||||
|
||||
return array
|
||||
|
||||
@ -340,7 +340,7 @@ class DMatrix(object):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
|
||||
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
|
||||
Data source of DMatrix.
|
||||
When data is string type, it represents the path libsvm format txt file,
|
||||
or binary file that xgboost can read from.
|
||||
@ -497,16 +497,20 @@ class DMatrix(object):
|
||||
|
||||
def _init_from_dt(self, data, nthread):
|
||||
"""
|
||||
Initialize data from a DataTable
|
||||
Initialize data from a datatable Frame.
|
||||
"""
|
||||
cols = []
|
||||
ptrs = (ctypes.c_void_p * data.ncols)()
|
||||
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
||||
# datatable>0.8.0
|
||||
for icol in range(data.ncols):
|
||||
col = data.internal.column(icol)
|
||||
cols.append(col)
|
||||
# int64_t (void*)
|
||||
ptr = col.data_pointer
|
||||
ptrs[icol] = ctypes.c_void_p(ptr)
|
||||
else:
|
||||
# datatable<=0.8.0
|
||||
from datatable.internal import frame_column_data_r
|
||||
for icol in range(data.ncols):
|
||||
ptrs[icol] = frame_column_data_r(data, icol)
|
||||
|
||||
# always return stypes for dt ingestion
|
||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user