Allow import via python datatable. (#3272)
Allow import via Python datatable. (Squashed commits: allow import via python datatable; write unit tests; refactor dt API functions; refactor python code; lint fixes; address review comments.)
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
eecf341ea7
commit
9ac163d0bb
@@ -38,7 +38,7 @@ try:
|
||||
except ImportError:
|
||||
|
||||
class MultiIndex(object):
|
||||
""" dummy for pandas.MultiIndex """
|
||||
""" dummy for pandas.MultiIndex """
|
||||
pass
|
||||
|
||||
class DataFrame(object):
|
||||
@@ -47,6 +47,18 @@ except ImportError:
|
||||
|
||||
PANDAS_INSTALLED = False
|
||||
|
||||
# datatable (optional dependency)
try:
    from datatable import DataTable
except ImportError:
    DT_INSTALLED = False

    class DataTable(object):
        """Stand-in for datatable.DataTable when the datatable package
        is not installed; lets isinstance checks fail cleanly."""
        pass
else:
    DT_INSTALLED = True
|
||||
|
||||
# sklearn
|
||||
try:
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
@@ -4,19 +4,18 @@
|
||||
"""Core XGBoost Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
import collections
|
||||
import ctypes
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED, DataTable
|
||||
from .libpath import find_lib_path
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED
|
||||
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
c_bst_ulong = ctypes.c_uint64
|
||||
|
||||
@@ -182,7 +181,7 @@ def _maybe_pandas_data(data, feature_names, feature_types):
|
||||
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
|
||||
|
||||
msg = """DataFrame.dtypes for data must be int, float or bool.
|
||||
Did not expect the data types in fields """
|
||||
Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
|
||||
if feature_names is None:
|
||||
@@ -219,6 +218,54 @@ def _maybe_pandas_label(label):
|
||||
return label
|
||||
|
||||
|
||||
# datatable ltype names accepted for DMatrix ingestion; any column whose
# ltype name is outside this set is rejected with a ValueError.
DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}

# ltype name -> feature-type string reported back for the dt path.
# NOTE(review): 'bool' maps to 'i', not 'int' — presumably the stype code
# expected by the native ingestion; confirm against the C API.
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
||||
|
||||
|
||||
def _maybe_dt_data(data, feature_names, feature_types):
    """
    Validate feature names and types if data table.

    Parameters
    ----------
    data : object
        Candidate data; passed through untouched unless it is a DataTable.
    feature_names : list or None
        User-supplied feature names; defaults to the table's column names.
    feature_types : list or None
        Must be None for a DataTable — the table carries its own types.

    Returns
    -------
    (data, feature_names, feature_types) triple, with feature_types
    derived from the table's ltypes via DT_TYPE_MAPPER2.

    Raises
    ------
    ValueError
        If a column has an unsupported ltype, or feature_types was given.
    """
    if not isinstance(data, DataTable):
        return data, feature_names, feature_types

    data_types_names = tuple(lt.name for lt in data.ltypes)
    # Single pass: collect offending columns directly instead of an
    # all() test followed by a second enumeration.
    bad_fields = [data.names[i]
                  for i, type_name in enumerate(data_types_names)
                  if type_name not in DT_TYPE_MAPPER]
    if bad_fields:
        # Was "DataFrame.types" — copy-paste from the pandas path; this
        # branch only ever fires for a DataTable.
        msg = """DataTable.types for data must be int, float or bool.
Did not expect the data types in fields """
        raise ValueError(msg + ', '.join(bad_fields))

    if feature_names is None:
        feature_names = data.names

    # always return stypes for dt ingestion
    if feature_types is not None:
        raise ValueError('DataTable has own feature types, cannot pass them in')
    feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)

    return data, feature_names, feature_types
|
||||
|
||||
|
||||
def _maybe_dt_array(array):
    """ Extract numpy array from single column data table.

    Non-DataTable inputs (including None) are returned unchanged; a
    one-column DataTable is converted to a float numpy array.

    Raises
    ------
    ValueError
        If the DataTable has more than one column.
    """
    # None can never be a DataTable instance, so the original trailing
    # `or array is None` clause was unreachable and has been dropped.
    if not isinstance(array, DataTable):
        return array

    if array.shape[1] > 1:
        raise ValueError('DataTable for label or weight cannot have multiple columns')

    # below requires new dt version
    # extract first column
    array = array.tonumpy()[:, 0].astype('float')

    return array
|
||||
|
||||
|
||||
class DMatrix(object):
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
@@ -237,7 +284,7 @@ class DMatrix(object):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
data : string/numpy array/scipy.sparse/pd.DataFrame
|
||||
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
|
||||
Data source of DMatrix.
|
||||
When data is string type, it represents the path libsvm format txt file,
|
||||
or binary file that xgboost can read from.
|
||||
@@ -266,7 +313,13 @@ class DMatrix(object):
|
||||
data, feature_names, feature_types = _maybe_pandas_data(data,
|
||||
feature_names,
|
||||
feature_types)
|
||||
|
||||
data, feature_names, feature_types = _maybe_dt_data(data,
|
||||
feature_names,
|
||||
feature_types)
|
||||
label = _maybe_pandas_label(label)
|
||||
label = _maybe_dt_array(label)
|
||||
weight = _maybe_dt_array(weight)
|
||||
|
||||
if isinstance(data, STRING_TYPES):
|
||||
self.handle = ctypes.c_void_p()
|
||||
@@ -279,19 +332,23 @@ class DMatrix(object):
|
||||
self._init_from_csc(data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
self._init_from_npy2d(data, missing, nthread)
|
||||
elif isinstance(data, DataTable):
|
||||
self._init_from_dt(data, nthread)
|
||||
else:
|
||||
try:
|
||||
csr = scipy.sparse.csr_matrix(data)
|
||||
self._init_from_csr(csr)
|
||||
except:
|
||||
raise TypeError('can not initialize DMatrix from {}'.format(type(data).__name__))
|
||||
raise TypeError('can not initialize DMatrix from'
|
||||
' {}'.format(type(data).__name__))
|
||||
|
||||
if label is not None:
|
||||
if isinstance(data, np.ndarray):
|
||||
if isinstance(label, np.ndarray):
|
||||
self.set_label_npy2d(label)
|
||||
else:
|
||||
self.set_label(label)
|
||||
if weight is not None:
|
||||
if isinstance(data, np.ndarray):
|
||||
if isinstance(weight, np.ndarray):
|
||||
self.set_weight_npy2d(weight)
|
||||
else:
|
||||
self.set_weight(weight)
|
||||
@@ -365,6 +422,33 @@ class DMatrix(object):
|
||||
ctypes.byref(self.handle),
|
||||
nthread))
|
||||
|
||||
def _init_from_dt(self, data, nthread):
    """
    Initialize the DMatrix handle from a DataTable.

    Passes one raw data pointer and one stype name per column to the
    native XGDMatrixCreateFromDT entry point.
    """
    # Keep the column objects alive in a local list while the native
    # call reads from their data pointers.
    column_refs = []
    data_ptrs = (ctypes.c_void_p * data.ncols)()
    # always return stypes for dt ingestion
    stype_names = (ctypes.c_char_p * data.ncols)()
    for idx in range(data.ncols):
        column = data.internal.column(idx)
        column_refs.append(column)
        # int64_t (void*)
        data_ptrs[idx] = ctypes.c_void_p(column.data_pointer)
        stype_names[idx] = ctypes.c_char_p(data.stypes[idx].name.encode('utf-8'))

    self.handle = ctypes.c_void_p()

    _check_call(_LIB.XGDMatrixCreateFromDT(
        data_ptrs, stype_names,
        c_bst_ulong(data.shape[0]),
        c_bst_ulong(data.shape[1]),
        ctypes.byref(self.handle),
        nthread))
|
||||
|
||||
def __del__(self):
|
||||
if self.handle is not None:
|
||||
_check_call(_LIB.XGDMatrixFree(self.handle))
|
||||
|
||||
Reference in New Issue
Block a user