Update datatable usage (#4123)

This commit is contained in:
Pasha Stetsenko 2019-02-16 11:44:09 -08:00 committed by Jiaming Yuan
parent 754fe8142b
commit ff2d4c99fa
2 changed files with 23 additions and 15 deletions

View File

@ -49,7 +49,11 @@ except ImportError:
# dt
try:
from datatable import DataTable
import datatable
if hasattr(datatable, "Frame"):
DataTable = datatable.Frame
else:
DataTable = datatable.DataTable
DT_INSTALLED = True
except ImportError:

View File

@ -287,10 +287,10 @@ def _maybe_dt_data(data, feature_names, feature_types):
return data, feature_names, feature_types
data_types_names = tuple(lt.name for lt in data.ltypes)
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
bad_fields = [data.names[i] for i, type_name in
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]
bad_fields = [data.names[i]
for i, type_name in enumerate(data_types_names)
if type_name not in DT_TYPE_MAPPER]
if bad_fields:
msg = """DataFrame.types for data must be int, float or bool.
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
@ -317,7 +317,7 @@ def _maybe_dt_array(array):
# below requires new dt version
# extract first column
array = array.tonumpy()[:, 0].astype('float')
array = array.to_numpy()[:, 0].astype('float')
return array
@ -340,7 +340,7 @@ class DMatrix(object):
"""
Parameters
----------
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from.
@ -497,16 +497,20 @@ class DMatrix(object):
def _init_from_dt(self, data, nthread):
"""
Initialize data from a DataTable
Initialize data from a datatable Frame.
"""
cols = []
ptrs = (ctypes.c_void_p * data.ncols)()
if hasattr(data, "internal") and hasattr(data.internal, "column"):
# datatable>0.8.0
for icol in range(data.ncols):
col = data.internal.column(icol)
cols.append(col)
# int64_t (void*)
ptr = col.data_pointer
ptrs[icol] = ctypes.c_void_p(ptr)
else:
# datatable<=0.8.0
from datatable.internal import frame_column_data_r
for icol in range(data.ncols):
ptrs[icol] = frame_column_data_r(data, icol)
# always return stypes for dt ingestion
feature_type_strings = (ctypes.c_char_p * data.ncols)()