Allow import via python datatable. (#3272)

* Allow import via python datatable.

* Write unit tests

* Refactor dt API functions

* Refactor python code

* Lint fixes

* Address review comments
This commit is contained in:
PSEUDOTENSOR / Jonathan McKinney 2018-06-20 16:16:18 -04:00 committed by Philip Hyunsu Cho
parent eecf341ea7
commit 9ac163d0bb
8 changed files with 352 additions and 13 deletions

View File

@ -219,6 +219,22 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const float *data, // NOLINT
bst_ulong nrow, bst_ulong ncol,
float missing, DMatrixHandle *out,
int nthread);
/*!
* \brief create matrix content from python data table
* \param data pointer to pointer to column data
* \param feature_stypes pointer to strings
* \param nrow number of rows
* \param ncol number columns
* \param out created dmatrix
* \param nthread number of threads (up to maximum cores available, if <=0 use all cores)
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromDT(void** data,
const char ** feature_stypes,
bst_ulong nrow,
bst_ulong ncol,
DMatrixHandle* out,
int nthread);
/*!
* \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced
@ -261,7 +277,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
* \brief set uint32 vector to a content in info
* \param handle a instance of data matrix
* \param field field name
* \param array pointer to float vector
* \param array pointer to unsigned int vector
* \param len length of array
* \return 0 when success, -1 when failure happens
*/

View File

@ -38,7 +38,7 @@ try:
except ImportError:
class MultiIndex(object):
""" dummy for pandas.MultiIndex """
""" dummy for pandas.MultiIndex """
pass
class DataFrame(object):
@ -47,6 +47,18 @@ except ImportError:
PANDAS_INSTALLED = False
# dt
try:
from datatable import DataTable
DT_INSTALLED = True
except ImportError:
class DataTable(object):
""" dummy for datatable.DataTable """
pass
DT_INSTALLED = False
# sklearn
try:
from sklearn.base import BaseEstimator

View File

@ -4,19 +4,18 @@
"""Core XGBoost Library."""
from __future__ import absolute_import
import sys
import os
import ctypes
import collections
import ctypes
import os
import re
import sys
import numpy as np
import scipy.sparse
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED, DataTable
from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64
@ -182,7 +181,7 @@ def _maybe_pandas_data(data, feature_names, feature_types):
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
msg = """DataFrame.dtypes for data must be int, float or bool.
Did not expect the data types in fields """
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None:
@ -219,6 +218,54 @@ def _maybe_pandas_label(label):
return label
DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
def _maybe_dt_data(data, feature_names, feature_types):
"""
Validate feature names and types if data table
"""
if not isinstance(data, DataTable):
return data, feature_names, feature_types
data_types_names = tuple(lt.name for lt in data.ltypes)
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
bad_fields = [data.names[i] for i, type_name in
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]
msg = """DataFrame.types for data must be int, float or bool.
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None:
feature_names = data.names
# always return stypes for dt ingestion
if feature_types is not None:
raise ValueError('DataTable has own feature types, cannot pass them in')
else:
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
return data, feature_names, feature_types
def _maybe_dt_array(array):
""" Extract numpy array from single column data table """
if not isinstance(array, DataTable) or array is None:
return array
if array.shape[1] > 1:
raise ValueError('DataTable for label or weight cannot have multiple columns')
# below requires new dt version
# extract first column
array = array.tonumpy()[:, 0].astype('float')
return array
class DMatrix(object):
"""Data Matrix used in XGBoost.
@ -237,7 +284,7 @@ class DMatrix(object):
"""
Parameters
----------
data : string/numpy array/scipy.sparse/pd.DataFrame
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from.
@ -266,7 +313,13 @@ class DMatrix(object):
data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names,
feature_types)
data, feature_names, feature_types = _maybe_dt_data(data,
feature_names,
feature_types)
label = _maybe_pandas_label(label)
label = _maybe_dt_array(label)
weight = _maybe_dt_array(weight)
if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p()
@ -279,19 +332,23 @@ class DMatrix(object):
self._init_from_csc(data)
elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, missing, nthread)
elif isinstance(data, DataTable):
self._init_from_dt(data, nthread)
else:
try:
csr = scipy.sparse.csr_matrix(data)
self._init_from_csr(csr)
except:
raise TypeError('can not initialize DMatrix from {}'.format(type(data).__name__))
raise TypeError('can not initialize DMatrix from'
' {}'.format(type(data).__name__))
if label is not None:
if isinstance(data, np.ndarray):
if isinstance(label, np.ndarray):
self.set_label_npy2d(label)
else:
self.set_label(label)
if weight is not None:
if isinstance(data, np.ndarray):
if isinstance(weight, np.ndarray):
self.set_weight_npy2d(weight)
else:
self.set_weight(weight)
@ -365,6 +422,33 @@ class DMatrix(object):
ctypes.byref(self.handle),
nthread))
def _init_from_dt(self, data, nthread):
"""
Initialize data from a DataTable
"""
cols = []
ptrs = (ctypes.c_void_p * data.ncols)()
for icol in range(data.ncols):
col = data.internal.column(icol)
cols.append(col)
# int64_t (void*)
ptr = col.data_pointer
ptrs[icol] = ctypes.c_void_p(ptr)
# always return stypes for dt ingestion
feature_type_strings = (ctypes.c_char_p * data.ncols)()
for icol in range(data.ncols):
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromDT(
ptrs, feature_type_strings,
c_bst_ulong(data.shape[0]),
c_bst_ulong(data.shape[1]),
ctypes.byref(self.handle),
nthread))
def __del__(self):
if self.handle is not None:
_check_call(_LIB.XGDMatrixFree(self.handle))

View File

@ -18,6 +18,7 @@
#include "../common/io.h"
#include "../common/group_data.h"
namespace xgboost {
// booster wrapper for backward compatible reason.
class Booster {
@ -439,6 +440,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
// const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread=nthreadmax;
int nthread_orig = omp_get_max_threads();
omp_set_num_threads(nthread);
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
@ -497,12 +499,150 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
}
}
}
// restore omp state
omp_set_num_threads(nthread_orig);
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
enum class DTType : uint8_t {
kFloat32 = 0,
kFloat64 = 1,
kBool8 = 2,
kInt32 = 3,
kInt8 = 4,
kInt16 = 5,
kInt64 = 6,
kUnknown = 7
};
DTType DTGetType(std::string type_string) {
if (type_string == "float32") {
return DTType::kFloat32;
} else if (type_string == "float64") {
return DTType::kFloat64;
} else if (type_string == "bool8") {
return DTType::kBool8;
} else if (type_string == "int32") {
return DTType::kInt32;
} else if (type_string == "int8") {
return DTType::kInt8;
} else if (type_string == "int16") {
return DTType::kInt16;
} else if (type_string == "int64") {
return DTType::kInt64;
} else {
LOG(FATAL) << "Unknown data table type.";
return DTType::kUnknown;
}
}
float DTGetValue(void* column, DTType dt_type, size_t ridx) {
float missing = std::numeric_limits<float>::quiet_NaN();
switch (dt_type) {
case DTType::kFloat32: {
float val = reinterpret_cast<float*>(column)[ridx];
return std::isfinite(val) ? val : missing;
}
case DTType::kFloat64: {
double val = reinterpret_cast<double*>(column)[ridx];
return std::isfinite(val) ? static_cast<float>(val) : missing;
}
case DTType::kBool8: {
bool val = reinterpret_cast<bool*>(column)[ridx];
return static_cast<float>(val);
}
case DTType::kInt32: {
int32_t val = reinterpret_cast<int32_t*>(column)[ridx];
return val != (-2147483647 - 1) ? static_cast<float>(val) : missing;
}
case DTType::kInt8: {
int8_t val = reinterpret_cast<int8_t*>(column)[ridx];
return val != -128 ? static_cast<float>(val) : missing;
}
case DTType::kInt16: {
int16_t val = reinterpret_cast<int16_t*>(column)[ridx];
return val != -32768 ? static_cast<float>(val) : missing;
}
case DTType::kInt64: {
int64_t val = reinterpret_cast<int64_t*>(column)[ridx];
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
: missing;
}
default: {
LOG(FATAL) << "Unknown data table type.";
return 0.0f;
}
}
}
XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
xgboost::bst_ulong nrow,
xgboost::bst_ulong ncol, DMatrixHandle* out,
int nthread) {
// avoid openmp unless enough data to be worth it to avoid overhead costs
if (nrow * ncol <= 10000 * 50) {
nthread = 1;
}
API_BEGIN();
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
if (nthread <= 0) nthread = nthreadmax;
int nthread_orig = omp_get_max_threads();
omp_set_num_threads(nthread);
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
data::SimpleCSRSource& mat = *source;
mat.page_.offset.resize(1 + nrow);
mat.info.num_row_ = nrow;
mat.info.num_col_ = ncol;
#pragma omp parallel num_threads(nthread)
{
// Count elements per row, column by column
for (auto j = 0; j < ncol; ++j) {
DTType dtype = DTGetType(feature_stypes[j]);
#pragma omp for schedule(static)
for (omp_ulong i = 0; i < nrow; ++i) {
float val = DTGetValue(data[j], dtype, i);
if (!std::isnan(val)) {
mat.page_.offset[i + 1]++;
}
}
}
}
// do cumulative sum (to avoid otherwise need to copy)
PrefixSum(&mat.page_.offset[0], mat.page_.offset.size());
mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());
// Fill data matrix (now that know size, no need for slow push_back())
std::vector<size_t> position(nrow);
#pragma omp parallel num_threads(nthread)
{
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
DTType dtype = DTGetType(feature_stypes[j]);
#pragma omp for schedule(static)
for (omp_ulong i = 0; i < nrow; ++i) {
float val = DTGetValue(data[j], dtype, i);
if (!std::isnan(val)) {
mat.page_.data[mat.page_.offset[i] + position[i]] = Entry(j, val);
position[i]++;
}
}
}
}
// restore omp state
omp_set_num_threads(nthread_orig);
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,

View File

@ -3,6 +3,34 @@
#include <xgboost/c_api.h>
#include <xgboost/data.h>
TEST(c_api, XGDMatrixCreateFromMatDT) {
std::vector<int> col0 = {0, -1, 3};
std::vector<float> col1 = {-4.0f, 2.0f, 0.0f};
const char *col0_type = "int32";
const char *col1_type = "float32";
std::vector<void *> data = {col0.data(), col1.data()};
std::vector<const char *> types = {col0_type, col1_type};
DMatrixHandle handle;
XGDMatrixCreateFromDT(data.data(), types.data(), 3, 2, &handle,
0);
std::shared_ptr<xgboost::DMatrix> dmat =
*static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
xgboost::MetaInfo &info = dmat->Info();
ASSERT_EQ(info.num_col_, 2);
ASSERT_EQ(info.num_row_, 3);
ASSERT_EQ(info.num_nonzero_, 6);
auto iter = dmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
auto batch = iter->Value();
ASSERT_EQ(batch[0][0].fvalue, 0.0f);
ASSERT_EQ(batch[0][1].fvalue, -4.0f);
ASSERT_EQ(batch[2][0].fvalue, 3.0f);
ASSERT_EQ(batch[2][1].fvalue, 0.0f);
}
}
TEST(c_api, XGDMatrixCreateFromMat_omp) {
std::vector<int> num_rows = {100, 11374, 15000};
for (auto row : num_rows) {

47
tests/python/test_dt.py Normal file
View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
import unittest
import testing as tm
import xgboost as xgb
try:
import datatable as dt
import pandas as pd
except ImportError:
pass
tm._skip_if_no_dt()
tm._skip_if_no_pandas()
class TestDataTable(unittest.TestCase):
def test_dt(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
labels = dt.Frame([1, 2])
dm = xgb.DMatrix(dtable, label=labels)
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'float', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
# overwrite feature_names
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
self.assertRaises(ValueError, xgb.DMatrix, dtable)
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
dtable = dt.Frame(df)
dm = xgb.DMatrix(dtable)
assert dm.feature_names == ['A=1', 'A=2']
assert dm.feature_types == ['int', 'int']
assert dm.num_row() == 3
assert dm.num_col() == 2

View File

@ -2,7 +2,7 @@
import nose
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
def _skip_if_no_sklearn():
@ -15,6 +15,11 @@ def _skip_if_no_pandas():
raise nose.SkipTest()
def _skip_if_no_dt():
if not DT_INSTALLED:
raise nose.SkipTest()
def _skip_if_no_matplotlib():
try:
import matplotlib.pyplot as _ # noqa

View File

@ -48,6 +48,13 @@ if [ ${TASK} == "python_test" ]; then
source activate python3
python --version
conda install numpy scipy pandas matplotlib nose scikit-learn
# Install data table from source
wget http://releases.llvm.org/5.0.2/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
tar xf clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
export LLVM5=$(pwd)/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04
python -m pip install datatable --no-binary datatable
python -m pip install graphviz pytest pytest-cov codecov
python -m nose tests/python || exit -1
py.test tests/python --cov=python-package/xgboost