diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 6c3b9071c..4b868d7bf 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -49,7 +49,11 @@ except ImportError:
 
 # dt
 try:
-    from datatable import DataTable
+    import datatable
+    if hasattr(datatable, "Frame"):
+        DataTable = datatable.Frame
+    else:
+        DataTable = datatable.DataTable
     DT_INSTALLED = True
 except ImportError:
 
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 8159bea57..c03079d1e 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -287,10 +287,10 @@ def _maybe_dt_data(data, feature_names, feature_types):
         return data, feature_names, feature_types
 
     data_types_names = tuple(lt.name for lt in data.ltypes)
-    if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
-        bad_fields = [data.names[i] for i, type_name in
-                      enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]
-
+    bad_fields = [data.names[i]
+                  for i, type_name in enumerate(data_types_names)
+                  if type_name not in DT_TYPE_MAPPER]
+    if bad_fields:
         msg = """DataFrame.types for data must be int, float or bool.
                 Did not expect the data types in fields """
         raise ValueError(msg + ', '.join(bad_fields))
@@ -317,7 +317,7 @@ def _maybe_dt_array(array):
 
     # below requires new dt version
     # extract first column
-    array = array.tonumpy()[:, 0].astype('float')
+    array = array.to_numpy()[:, 0].astype('float')
 
     return array
 
@@ -340,7 +340,7 @@ class DMatrix(object):
         """
         Parameters
         ----------
-        data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
+        data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
            Data source of DMatrix.
            When data is string type, it represents the path libsvm format txt file,
            or binary file that xgboost can read from.
@@ -497,16 +497,20 @@ class DMatrix(object):
 
     def _init_from_dt(self, data, nthread):
         """
-        Initialize data from a DataTable
+        Initialize data from a datatable Frame.
         """
-        cols = []
         ptrs = (ctypes.c_void_p * data.ncols)()
-        for icol in range(data.ncols):
-            col = data.internal.column(icol)
-            cols.append(col)
-            # int64_t (void*)
-            ptr = col.data_pointer
-            ptrs[icol] = ctypes.c_void_p(ptr)
+        if hasattr(data, "internal") and hasattr(data.internal, "column"):
+            # datatable<=0.8.0
+            for icol in range(data.ncols):
+                col = data.internal.column(icol)
+                ptr = col.data_pointer
+                ptrs[icol] = ctypes.c_void_p(ptr)
+        else:
+            # datatable>0.8.0
+            from datatable.internal import frame_column_data_r
+            for icol in range(data.ncols):
+                ptrs[icol] = frame_column_data_r(data, icol)
 
         # always return stypes for dt ingestion
         feature_type_strings = (ctypes.c_char_p * data.ncols)()
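
Usage note (not part of the patch): a minimal sketch of how the datatable ingestion path touched by this diff can be exercised, assuming xgboost is built with this change and some version of datatable is installed. The column names and random data below are made up for illustration.

import numpy as np
import datatable as dt
import xgboost as xgb

# Build a two-column Frame; on old datatable releases the DataTable name would
# be resolved instead, which is what the compat.py hunk now accounts for.
X = dt.Frame({"f0": np.random.rand(100), "f1": np.random.rand(100)})
y = np.random.randint(0, 2, size=100)

# DMatrix construction goes through _init_from_dt, which picks the
# column-pointer API matching the installed datatable version.
dmat = xgb.DMatrix(X, label=y)
print(dmat.num_row(), dmat.num_col())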