Allow import via python datatable. (#3272)

* Allow import via python datatable.

* Write unit tests

* Refactor dt API functions

* Refactor python code

* Lint fixes

* Address review comments
This commit is contained in:
PSEUDOTENSOR / Jonathan McKinney 2018-06-20 16:16:18 -04:00 committed by Philip Hyunsu Cho
parent eecf341ea7
commit 9ac163d0bb
8 changed files with 352 additions and 13 deletions

View File

@ -219,6 +219,22 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const float *data, // NOLINT
bst_ulong nrow, bst_ulong ncol, bst_ulong nrow, bst_ulong ncol,
float missing, DMatrixHandle *out, float missing, DMatrixHandle *out,
int nthread); int nthread);
/*!
* \brief create matrix content from python data table
* \param data array of ncol pointers, one per column, each pointing to that
*        column's contiguous data buffer of nrow elements
* \param feature_stypes array of ncol C strings naming the storage type of
*        each column (e.g. "float32", "float64", "bool8", "int8", "int16",
*        "int32", "int64")
* \param nrow number of rows
* \param ncol number of columns
* \param out created dmatrix
* \param nthread number of threads (up to maximum cores available, if <=0 use all cores)
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromDT(void** data,
const char ** feature_stypes,
bst_ulong nrow,
bst_ulong ncol,
DMatrixHandle* out,
int nthread);
/*! /*!
* \brief create a new dmatrix from sliced content of existing matrix * \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced * \param handle instance of data matrix to be sliced
@ -261,7 +277,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
* \brief set uint32 vector to a content in info * \brief set uint32 vector to a content in info
* \param handle a instance of data matrix * \param handle a instance of data matrix
* \param field field name * \param field field name
* \param array pointer to float vector * \param array pointer to unsigned int vector
* \param len length of array * \param len length of array
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */

View File

@ -38,7 +38,7 @@ try:
except ImportError: except ImportError:
class MultiIndex(object): class MultiIndex(object):
""" dummy for pandas.MultiIndex """ """ dummy for pandas.MultiIndex """
pass pass
class DataFrame(object): class DataFrame(object):
@ -47,6 +47,18 @@ except ImportError:
PANDAS_INSTALLED = False PANDAS_INSTALLED = False
# dt
# datatable is an optional dependency: when it cannot be imported, expose a
# placeholder class so isinstance() checks elsewhere still work.
try:
    from datatable import DataTable
except ImportError:
    DT_INSTALLED = False

    class DataTable(object):
        """ dummy for datatable.DataTable """
else:
    DT_INSTALLED = True
# sklearn # sklearn
try: try:
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator

View File

@ -4,19 +4,18 @@
"""Core XGBoost Library.""" """Core XGBoost Library."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os
import ctypes
import collections import collections
import ctypes
import os
import re import re
import sys
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED, DataTable
from .libpath import find_lib_path from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64 c_bst_ulong = ctypes.c_uint64
@ -182,7 +181,7 @@ def _maybe_pandas_data(data, feature_names, feature_types):
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER] enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
msg = """DataFrame.dtypes for data must be int, float or bool. msg = """DataFrame.dtypes for data must be int, float or bool.
Did not expect the data types in fields """ Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields)) raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None: if feature_names is None:
@ -219,6 +218,54 @@ def _maybe_pandas_label(label):
return label return label
# Logical datatable types accepted as DMatrix input.
DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}

# Logical datatable type -> feature type string reported on the DMatrix.
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}


def _maybe_dt_data(data, feature_names, feature_types):
    """
    Validate feature names and types if data table

    Parameters
    ----------
    data : object
        Candidate input. Anything that is not a DataTable is passed through
        untouched together with the given names and types.
    feature_names : sequence of str or None
        Explicit feature names; when None the table's column names are used.
    feature_types : sequence of str or None
        Must be None for a DataTable, the types are derived from the table.

    Returns
    -------
    (data, feature_names, feature_types)
        For a DataTable, feature_types is a list with one entry of
        DT_TYPE_MAPPER2 per column.

    Raises
    ------
    ValueError
        If a column has a logical type other than bool, int or real, or if
        feature_types is passed together with a DataTable.
    """
    if not isinstance(data, DataTable):
        return data, feature_names, feature_types

    data_types_names = tuple(lt.name for lt in data.ltypes)

    # Single pass: collect the names of every column with an unsupported type.
    bad_fields = [data.names[i] for i, type_name in enumerate(data_types_names)
                  if type_name not in DT_TYPE_MAPPER]
    if bad_fields:
        msg = ('DataTable.ltypes for data must be int, float or bool.\n'
               'Did not expect the data types in fields ')
        raise ValueError(msg + ', '.join(bad_fields))

    if feature_names is None:
        feature_names = data.names

    # always return stypes for dt ingestion
    if feature_types is not None:
        raise ValueError('DataTable has own feature types, cannot pass them in')

    # A plain list comprehension also works for a table with zero columns,
    # where np.vectorize would refuse the size-0 input.
    feature_types = [DT_TYPE_MAPPER2[type_name] for type_name in data_types_names]

    return data, feature_names, feature_types
def _maybe_dt_array(array):
    """ Extract numpy array from single column data table

    Inputs that are not a DataTable (including None) are returned unchanged.
    A DataTable with more than one column raises ValueError.
    """
    if isinstance(array, DataTable):
        if array.shape[1] > 1:
            raise ValueError('DataTable for label or weight cannot have multiple columns')

        # below requires new dt version
        # extract first column
        first_column = array.tonumpy()[:, 0]
        return first_column.astype('float')

    return array
class DMatrix(object): class DMatrix(object):
"""Data Matrix used in XGBoost. """Data Matrix used in XGBoost.
@ -237,7 +284,7 @@ class DMatrix(object):
""" """
Parameters Parameters
---------- ----------
data : string/numpy array/scipy.sparse/pd.DataFrame data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
Data source of DMatrix. Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file, When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from. or binary file that xgboost can read from.
@ -266,7 +313,13 @@ class DMatrix(object):
data, feature_names, feature_types = _maybe_pandas_data(data, data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names, feature_names,
feature_types) feature_types)
data, feature_names, feature_types = _maybe_dt_data(data,
feature_names,
feature_types)
label = _maybe_pandas_label(label) label = _maybe_pandas_label(label)
label = _maybe_dt_array(label)
weight = _maybe_dt_array(weight)
if isinstance(data, STRING_TYPES): if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
@ -279,19 +332,23 @@ class DMatrix(object):
self._init_from_csc(data) self._init_from_csc(data)
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, missing, nthread) self._init_from_npy2d(data, missing, nthread)
elif isinstance(data, DataTable):
self._init_from_dt(data, nthread)
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
self._init_from_csr(csr) self._init_from_csr(csr)
except: except:
raise TypeError('can not initialize DMatrix from {}'.format(type(data).__name__)) raise TypeError('can not initialize DMatrix from'
' {}'.format(type(data).__name__))
if label is not None: if label is not None:
if isinstance(data, np.ndarray): if isinstance(label, np.ndarray):
self.set_label_npy2d(label) self.set_label_npy2d(label)
else: else:
self.set_label(label) self.set_label(label)
if weight is not None: if weight is not None:
if isinstance(data, np.ndarray): if isinstance(weight, np.ndarray):
self.set_weight_npy2d(weight) self.set_weight_npy2d(weight)
else: else:
self.set_weight(weight) self.set_weight(weight)
@ -365,6 +422,33 @@ class DMatrix(object):
ctypes.byref(self.handle), ctypes.byref(self.handle),
nthread)) nthread))
def _init_from_dt(self, data, nthread):
    """
    Initialize data from a DataTable

    Builds one ctypes array of per-column data pointers and one of
    per-column storage-type names, then hands both to
    XGDMatrixCreateFromDT, which fills self.handle.
    """
    ncols = data.ncols

    column_ptrs = (ctypes.c_void_p * ncols)()
    # NOTE(review): the column objects are collected but never read again;
    # presumably this keeps them referenced while the native call uses the
    # raw pointers - confirm before removing.
    live_columns = []
    for idx in range(ncols):
        column = data.internal.column(idx)
        live_columns.append(column)
        # int64_t (void*)
        column_ptrs[idx] = ctypes.c_void_p(column.data_pointer)

    # always return stypes for dt ingestion
    stype_names = (ctypes.c_char_p * ncols)()
    for idx in range(ncols):
        encoded_name = data.stypes[idx].name.encode('utf-8')
        stype_names[idx] = ctypes.c_char_p(encoded_name)

    self.handle = ctypes.c_void_p()
    status = _LIB.XGDMatrixCreateFromDT(
        column_ptrs, stype_names,
        c_bst_ulong(data.shape[0]),
        c_bst_ulong(data.shape[1]),
        ctypes.byref(self.handle),
        nthread)
    _check_call(status)
def __del__(self): def __del__(self):
if self.handle is not None: if self.handle is not None:
_check_call(_LIB.XGDMatrixFree(self.handle)) _check_call(_LIB.XGDMatrixFree(self.handle))

View File

@ -18,6 +18,7 @@
#include "../common/io.h" #include "../common/io.h"
#include "../common/group_data.h" #include "../common/group_data.h"
namespace xgboost { namespace xgboost {
// booster wrapper for backward compatible reason. // booster wrapper for backward compatible reason.
class Booster { class Booster {
@ -439,6 +440,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1); const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
// const int nthreadmax = omp_get_max_threads(); // const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread=nthreadmax; if (nthread <= 0) nthread=nthreadmax;
int nthread_orig = omp_get_max_threads();
omp_set_num_threads(nthread); omp_set_num_threads(nthread);
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource()); std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
@ -497,12 +499,150 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
} }
} }
} }
// restore omp state
omp_set_num_threads(nthread_orig);
mat.info.num_nonzero_ = mat.page_.data.size(); mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source))); *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END(); API_END();
} }
// Storage types of a python data table column that can be ingested.
// Each enumerator corresponds to the type-name string accepted by DTGetType.
enum class DTType : uint8_t {
kFloat32 = 0,  // "float32"
kFloat64 = 1,  // "float64"
kBool8 = 2,    // "bool8"
kInt32 = 3,    // "int32"
kInt8 = 4,     // "int8"
kInt16 = 5,    // "int16"
kInt64 = 6,    // "int64"
kUnknown = 7   // returned after the fatal log for an unrecognised name
};
// Translate a data table storage-type name into its DTType enumerator.
// An unrecognised name is reported through LOG(FATAL); kUnknown is returned
// only if that log call returns.
DTType DTGetType(std::string type_string) {
  static const struct {
    const char* name;
    DTType type;
  } kKnownTypes[] = {
    {"float32", DTType::kFloat32},
    {"float64", DTType::kFloat64},
    {"bool8", DTType::kBool8},
    {"int32", DTType::kInt32},
    {"int8", DTType::kInt8},
    {"int16", DTType::kInt16},
    {"int64", DTType::kInt64},
  };

  for (const auto& entry : kKnownTypes) {
    if (type_string == entry.name) {
      return entry.type;
    }
  }

  LOG(FATAL) << "Unknown data table type.";
  return DTType::kUnknown;
}
// Read element `ridx` of a raw column buffer as a float.
// Non-finite floating point values, and the minimum representable value of
// each signed integer type, are reported as NaN (the "missing" marker).
// Bool columns are converted directly with no missing-value check.
float DTGetValue(void* column, DTType dt_type, size_t ridx) {
  const float kMissing = std::numeric_limits<float>::quiet_NaN();
  switch (dt_type) {
    case DTType::kFloat32: {
      const float v = static_cast<float*>(column)[ridx];
      if (std::isfinite(v)) return v;
      return kMissing;
    }
    case DTType::kFloat64: {
      const double v = static_cast<double*>(column)[ridx];
      if (std::isfinite(v)) return static_cast<float>(v);
      return kMissing;
    }
    case DTType::kBool8: {
      const bool v = static_cast<bool*>(column)[ridx];
      return static_cast<float>(v);
    }
    case DTType::kInt32: {
      const int32_t v = static_cast<int32_t*>(column)[ridx];
      if (v == std::numeric_limits<int32_t>::min()) return kMissing;
      return static_cast<float>(v);
    }
    case DTType::kInt8: {
      const int8_t v = static_cast<int8_t*>(column)[ridx];
      if (v == std::numeric_limits<int8_t>::min()) return kMissing;
      return static_cast<float>(v);
    }
    case DTType::kInt16: {
      const int16_t v = static_cast<int16_t*>(column)[ridx];
      if (v == std::numeric_limits<int16_t>::min()) return kMissing;
      return static_cast<float>(v);
    }
    case DTType::kInt64: {
      const int64_t v = static_cast<int64_t*>(column)[ridx];
      if (v == std::numeric_limits<int64_t>::min()) return kMissing;
      return static_cast<float>(v);
    }
    default: {
      LOG(FATAL) << "Unknown data table type.";
      return 0.0f;
    }
  }
}
// Build a CSR DMatrix from column-major data table buffers.
//   data           : ncol pointers, data[j] is the buffer of column j
//   feature_stypes : ncol storage-type names understood by DTGetType
//   nrow, ncol     : table dimensions
//   out            : receives the newly allocated DMatrix handle
//   nthread        : OpenMP threads to use; <= 0 picks a default, and small
//                    inputs are always processed single-threaded
// Values that DTGetValue reports as NaN are treated as missing and skipped.
// Returns 0 on success, -1 on failure (through API_BEGIN / API_END).
XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
                                  xgboost::bst_ulong nrow,
                                  xgboost::bst_ulong ncol, DMatrixHandle* out,
                                  int nthread) {
  // avoid openmp unless enough data to be worth it to avoid overhead costs
  if (nrow * ncol <= 10000 * 50) {
    nthread = 1;
  }

  API_BEGIN();
  // Resolve every column type once, in serial code and before the OpenMP
  // thread count is changed. DTGetType reports unknown names via LOG(FATAL);
  // doing that here keeps the failure out of the parallel regions below and
  // avoids re-parsing the same strings on every thread in both passes.
  std::vector<DTType> dtypes(ncol);
  for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
    dtypes[j] = DTGetType(feature_stypes[j]);
  }

  const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
  if (nthread <= 0) nthread = nthreadmax;
  int nthread_orig = omp_get_max_threads();
  omp_set_num_threads(nthread);

  std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
  data::SimpleCSRSource& mat = *source;
  mat.page_.offset.resize(1 + nrow);
  mat.info.num_row_ = nrow;
  mat.info.num_col_ = ncol;

  #pragma omp parallel num_threads(nthread)
  {
    // Count elements per row, column by column
    for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
      const DTType dtype = dtypes[j];
      #pragma omp for schedule(static)
      for (omp_ulong i = 0; i < nrow; ++i) {
        float val = DTGetValue(data[j], dtype, i);
        if (!std::isnan(val)) {
          mat.page_.offset[i + 1]++;
        }
      }
    }
  }
  // do cumulative sum (to avoid otherwise need to copy)
  PrefixSum(&mat.page_.offset[0], mat.page_.offset.size());

  mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());

  // Fill data matrix (now that know size, no need for slow push_back())
  // position[i] is the number of entries already written into row i.
  std::vector<size_t> position(nrow);
  #pragma omp parallel num_threads(nthread)
  {
    for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
      const DTType dtype = dtypes[j];
      #pragma omp for schedule(static)
      for (omp_ulong i = 0; i < nrow; ++i) {
        float val = DTGetValue(data[j], dtype, i);
        if (!std::isnan(val)) {
          mat.page_.data[mat.page_.offset[i] + position[i]] = Entry(j, val);
          position[i]++;
        }
      }
    }
  }

  // restore omp state
  omp_set_num_threads(nthread_orig);

  mat.info.num_nonzero_ = mat.page_.data.size();
  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
  API_END();
}
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset, const int* idxset,
xgboost::bst_ulong len, xgboost::bst_ulong len,

View File

@ -3,6 +3,34 @@
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
#include <xgboost/data.h> #include <xgboost/data.h>
// Builds a 3x2 DMatrix from one int32 and one float32 column and checks the
// resulting shape, entry count and a sample of the stored values.
TEST(c_api, XGDMatrixCreateFromMatDT) {
  std::vector<int> col0 = {0, -1, 3};
  std::vector<float> col1 = {-4.0f, 2.0f, 0.0f};
  const char *col0_type = "int32";
  const char *col1_type = "float32";
  std::vector<void *> data = {col0.data(), col1.data()};
  std::vector<const char *> types = {col0_type, col1_type};
  DMatrixHandle handle;
  // The C API signals failure through its return code; do not dereference
  // the handle unless creation actually succeeded.
  ASSERT_EQ(XGDMatrixCreateFromDT(data.data(), types.data(), 3, 2, &handle,
                                  0),
            0);
  {
    // Scoped so this copy of the shared_ptr is gone before the handle is freed.
    std::shared_ptr<xgboost::DMatrix> dmat =
        *static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
    xgboost::MetaInfo &info = dmat->Info();
    ASSERT_EQ(info.num_col_, 2);
    ASSERT_EQ(info.num_row_, 3);
    ASSERT_EQ(info.num_nonzero_, 6);

    auto iter = dmat->RowIterator();
    iter->BeforeFirst();
    while (iter->Next()) {
      auto batch = iter->Value();
      ASSERT_EQ(batch[0][0].fvalue, 0.0f);
      ASSERT_EQ(batch[0][1].fvalue, -4.0f);
      ASSERT_EQ(batch[2][0].fvalue, 3.0f);
      ASSERT_EQ(batch[2][1].fvalue, 0.0f);
    }
  }
  // Release the handle allocated by the C API so the test does not leak it.
  ASSERT_EQ(XGDMatrixFree(handle), 0);
}
TEST(c_api, XGDMatrixCreateFromMat_omp) { TEST(c_api, XGDMatrixCreateFromMat_omp) {
std::vector<int> num_rows = {100, 11374, 15000}; std::vector<int> num_rows = {100, 11374, 15000};
for (auto row : num_rows) { for (auto row : num_rows) {

47
tests/python/test_dt.py Normal file
View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
import unittest
import testing as tm
import xgboost as xgb
try:
import datatable as dt
import pandas as pd
except ImportError:
pass
tm._skip_if_no_dt()
tm._skip_if_no_pandas()
class TestDataTable(unittest.TestCase):
    """Checks that a DMatrix can be built from datatable frames."""

    def test_dt(self):
        """Feature names/types, shape and type validation for dt input."""
        mixed_df = pd.DataFrame([[1, 2., True], [2, 3., False]],
                                columns=['a', 'b', 'c'])
        mixed_frame = dt.Frame(mixed_df)

        # names and types are taken from the frame, label is a frame too
        label_frame = dt.Frame([1, 2])
        dmat = xgb.DMatrix(mixed_frame, label=label_frame)
        assert dmat.feature_names == ['a', 'b', 'c']
        assert dmat.feature_types == ['int', 'float', 'i']
        assert dmat.num_row() == 2
        assert dmat.num_col() == 3

        # overwrite feature_names
        renamed = xgb.DMatrix(mixed_frame, label=pd.Series([1, 2]),
                              feature_names=['x', 'y', 'z'])
        assert renamed.feature_names == ['x', 'y', 'z']
        assert renamed.num_row() == 2
        assert renamed.num_col() == 3

        # incorrect dtypes
        string_df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
                                 columns=['a', 'b', 'c'])
        string_frame = dt.Frame(string_df)
        self.assertRaises(ValueError, xgb.DMatrix, string_frame)

        # column names containing '=' are kept as they are
        int_df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
        int_frame = dt.Frame(int_df)
        int_dmat = xgb.DMatrix(int_frame)
        assert int_dmat.feature_names == ['A=1', 'A=2']
        assert int_dmat.feature_types == ['int', 'int']
        assert int_dmat.num_row() == 3
        assert int_dmat.num_col() == 2

View File

@ -2,7 +2,7 @@
import nose import nose
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
def _skip_if_no_sklearn(): def _skip_if_no_sklearn():
@ -15,6 +15,11 @@ def _skip_if_no_pandas():
raise nose.SkipTest() raise nose.SkipTest()
def _skip_if_no_dt():
    """Raise nose.SkipTest when the datatable package is not installed."""
    if DT_INSTALLED:
        return
    raise nose.SkipTest()
def _skip_if_no_matplotlib(): def _skip_if_no_matplotlib():
try: try:
import matplotlib.pyplot as _ # noqa import matplotlib.pyplot as _ # noqa

View File

@ -48,6 +48,13 @@ if [ ${TASK} == "python_test" ]; then
source activate python3 source activate python3
python --version python --version
conda install numpy scipy pandas matplotlib nose scikit-learn conda install numpy scipy pandas matplotlib nose scikit-learn
# Install data table from source
wget http://releases.llvm.org/5.0.2/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
tar xf clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
export LLVM5=$(pwd)/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04
python -m pip install datatable --no-binary datatable
python -m pip install graphviz pytest pytest-cov codecov python -m pip install graphviz pytest pytest-cov codecov
python -m nose tests/python || exit -1 python -m nose tests/python || exit -1
py.test tests/python --cov=python-package/xgboost py.test tests/python --cov=python-package/xgboost