Update datatable usage (#4123)

This commit is contained in:
Pasha Stetsenko 2019-02-16 11:44:09 -08:00 committed by Jiaming Yuan
parent 754fe8142b
commit ff2d4c99fa
2 changed files with 23 additions and 15 deletions

View File

@ -49,7 +49,11 @@ except ImportError:
# dt # dt
try: try:
from datatable import DataTable import datatable
if hasattr(datatable, "Frame"):
DataTable = datatable.Frame
else:
DataTable = datatable.DataTable
DT_INSTALLED = True DT_INSTALLED = True
except ImportError: except ImportError:

View File

@ -287,10 +287,10 @@ def _maybe_dt_data(data, feature_names, feature_types):
return data, feature_names, feature_types return data, feature_names, feature_types
data_types_names = tuple(lt.name for lt in data.ltypes) data_types_names = tuple(lt.name for lt in data.ltypes)
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names): bad_fields = [data.names[i]
bad_fields = [data.names[i] for i, type_name in for i, type_name in enumerate(data_types_names)
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER] if type_name not in DT_TYPE_MAPPER]
if bad_fields:
msg = """DataFrame.types for data must be int, float or bool. msg = """DataFrame.types for data must be int, float or bool.
Did not expect the data types in fields """ Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields)) raise ValueError(msg + ', '.join(bad_fields))
@ -317,7 +317,7 @@ def _maybe_dt_array(array):
# below requires new dt version # below requires new dt version
# extract first column # extract first column
array = array.tonumpy()[:, 0].astype('float') array = array.to_numpy()[:, 0].astype('float')
return array return array
@ -340,7 +340,7 @@ class DMatrix(object):
""" """
Parameters Parameters
---------- ----------
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
Data source of DMatrix. Data source of DMatrix.
When data is string type, it represents the path to a libsvm format txt file, When data is string type, it represents the path to a libsvm format txt file,
or binary file that xgboost can read from. or binary file that xgboost can read from.
@ -497,16 +497,20 @@ class DMatrix(object):
def _init_from_dt(self, data, nthread): def _init_from_dt(self, data, nthread):
""" """
Initialize data from a DataTable Initialize data from a datatable Frame.
""" """
cols = []
ptrs = (ctypes.c_void_p * data.ncols)() ptrs = (ctypes.c_void_p * data.ncols)()
for icol in range(data.ncols): if hasattr(data, "internal") and hasattr(data.internal, "column"):
col = data.internal.column(icol) # datatable<=0.8.0
cols.append(col) for icol in range(data.ncols):
# int64_t (void*) col = data.internal.column(icol)
ptr = col.data_pointer ptr = col.data_pointer
ptrs[icol] = ctypes.c_void_p(ptr) ptrs[icol] = ctypes.c_void_p(ptr)
else:
# datatable>0.8.0
from datatable.internal import frame_column_data_r
for icol in range(data.ncols):
ptrs[icol] = frame_column_data_r(data, icol)
# always return stypes for dt ingestion # always return stypes for dt ingestion
feature_type_strings = (ctypes.c_char_p * data.ncols)() feature_type_strings = (ctypes.c_char_p * data.ncols)()