Allow import via python datatable. (#3272)
* Allow import via python datatable. * Write unit tests * Refactor dt API functions * Refactor python code * Lint fixes * Address review comments
This commit is contained in:
parent
eecf341ea7
commit
9ac163d0bb
@ -219,6 +219,22 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const float *data, // NOLINT
|
||||
bst_ulong nrow, bst_ulong ncol,
|
||||
float missing, DMatrixHandle *out,
|
||||
int nthread);
|
||||
/*!
|
||||
* \brief create matrix content from python data table
|
||||
* \param data pointer to pointer to column data
|
||||
* \param feature_stypes pointer to strings
|
||||
* \param nrow number of rows
|
||||
* \param ncol number columns
|
||||
* \param out created dmatrix
|
||||
* \param nthread number of threads (up to maximum cores available, if <=0 use all cores)
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixCreateFromDT(void** data,
|
||||
const char ** feature_stypes,
|
||||
bst_ulong nrow,
|
||||
bst_ulong ncol,
|
||||
DMatrixHandle* out,
|
||||
int nthread);
|
||||
/*!
|
||||
* \brief create a new dmatrix from sliced content of existing matrix
|
||||
* \param handle instance of data matrix to be sliced
|
||||
@ -261,7 +277,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
* \brief set uint32 vector to a content in info
|
||||
* \param handle a instance of data matrix
|
||||
* \param field field name
|
||||
* \param array pointer to float vector
|
||||
* \param array pointer to unsigned int vector
|
||||
* \param len length of array
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
|
||||
@ -47,6 +47,18 @@ except ImportError:
|
||||
|
||||
PANDAS_INSTALLED = False
|
||||
|
||||
# dt
|
||||
try:
|
||||
from datatable import DataTable
|
||||
DT_INSTALLED = True
|
||||
except ImportError:
|
||||
|
||||
class DataTable(object):
|
||||
""" dummy for datatable.DataTable """
|
||||
pass
|
||||
|
||||
DT_INSTALLED = False
|
||||
|
||||
# sklearn
|
||||
try:
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
@ -4,19 +4,18 @@
|
||||
"""Core XGBoost Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
import collections
|
||||
import ctypes
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED, DataTable
|
||||
from .libpath import find_lib_path
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED
|
||||
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
c_bst_ulong = ctypes.c_uint64
|
||||
|
||||
@ -182,7 +181,7 @@ def _maybe_pandas_data(data, feature_names, feature_types):
|
||||
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
|
||||
|
||||
msg = """DataFrame.dtypes for data must be int, float or bool.
|
||||
Did not expect the data types in fields """
|
||||
Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
|
||||
if feature_names is None:
|
||||
@ -219,6 +218,54 @@ def _maybe_pandas_label(label):
|
||||
return label
|
||||
|
||||
|
||||
DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
|
||||
|
||||
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
||||
|
||||
|
||||
def _maybe_dt_data(data, feature_names, feature_types):
|
||||
"""
|
||||
Validate feature names and types if data table
|
||||
"""
|
||||
if not isinstance(data, DataTable):
|
||||
return data, feature_names, feature_types
|
||||
|
||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
|
||||
bad_fields = [data.names[i] for i, type_name in
|
||||
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]
|
||||
|
||||
msg = """DataFrame.types for data must be int, float or bool.
|
||||
Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
|
||||
if feature_names is None:
|
||||
feature_names = data.names
|
||||
|
||||
# always return stypes for dt ingestion
|
||||
if feature_types is not None:
|
||||
raise ValueError('DataTable has own feature types, cannot pass them in')
|
||||
else:
|
||||
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
|
||||
|
||||
return data, feature_names, feature_types
|
||||
|
||||
|
||||
def _maybe_dt_array(array):
|
||||
""" Extract numpy array from single column data table """
|
||||
if not isinstance(array, DataTable) or array is None:
|
||||
return array
|
||||
|
||||
if array.shape[1] > 1:
|
||||
raise ValueError('DataTable for label or weight cannot have multiple columns')
|
||||
|
||||
# below requires new dt version
|
||||
# extract first column
|
||||
array = array.tonumpy()[:, 0].astype('float')
|
||||
|
||||
return array
|
||||
|
||||
|
||||
class DMatrix(object):
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
@ -237,7 +284,7 @@ class DMatrix(object):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
data : string/numpy array/scipy.sparse/pd.DataFrame
|
||||
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
|
||||
Data source of DMatrix.
|
||||
When data is string type, it represents the path libsvm format txt file,
|
||||
or binary file that xgboost can read from.
|
||||
@ -266,7 +313,13 @@ class DMatrix(object):
|
||||
data, feature_names, feature_types = _maybe_pandas_data(data,
|
||||
feature_names,
|
||||
feature_types)
|
||||
|
||||
data, feature_names, feature_types = _maybe_dt_data(data,
|
||||
feature_names,
|
||||
feature_types)
|
||||
label = _maybe_pandas_label(label)
|
||||
label = _maybe_dt_array(label)
|
||||
weight = _maybe_dt_array(weight)
|
||||
|
||||
if isinstance(data, STRING_TYPES):
|
||||
self.handle = ctypes.c_void_p()
|
||||
@ -279,19 +332,23 @@ class DMatrix(object):
|
||||
self._init_from_csc(data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
self._init_from_npy2d(data, missing, nthread)
|
||||
elif isinstance(data, DataTable):
|
||||
self._init_from_dt(data, nthread)
|
||||
else:
|
||||
try:
|
||||
csr = scipy.sparse.csr_matrix(data)
|
||||
self._init_from_csr(csr)
|
||||
except:
|
||||
raise TypeError('can not initialize DMatrix from {}'.format(type(data).__name__))
|
||||
raise TypeError('can not initialize DMatrix from'
|
||||
' {}'.format(type(data).__name__))
|
||||
|
||||
if label is not None:
|
||||
if isinstance(data, np.ndarray):
|
||||
if isinstance(label, np.ndarray):
|
||||
self.set_label_npy2d(label)
|
||||
else:
|
||||
self.set_label(label)
|
||||
if weight is not None:
|
||||
if isinstance(data, np.ndarray):
|
||||
if isinstance(weight, np.ndarray):
|
||||
self.set_weight_npy2d(weight)
|
||||
else:
|
||||
self.set_weight(weight)
|
||||
@ -365,6 +422,33 @@ class DMatrix(object):
|
||||
ctypes.byref(self.handle),
|
||||
nthread))
|
||||
|
||||
def _init_from_dt(self, data, nthread):
|
||||
"""
|
||||
Initialize data from a DataTable
|
||||
"""
|
||||
cols = []
|
||||
ptrs = (ctypes.c_void_p * data.ncols)()
|
||||
for icol in range(data.ncols):
|
||||
col = data.internal.column(icol)
|
||||
cols.append(col)
|
||||
# int64_t (void*)
|
||||
ptr = col.data_pointer
|
||||
ptrs[icol] = ctypes.c_void_p(ptr)
|
||||
|
||||
# always return stypes for dt ingestion
|
||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||
for icol in range(data.ncols):
|
||||
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
|
||||
|
||||
self.handle = ctypes.c_void_p()
|
||||
|
||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
||||
ptrs, feature_type_strings,
|
||||
c_bst_ulong(data.shape[0]),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
ctypes.byref(self.handle),
|
||||
nthread))
|
||||
|
||||
def __del__(self):
|
||||
if self.handle is not None:
|
||||
_check_call(_LIB.XGDMatrixFree(self.handle))
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "../common/io.h"
|
||||
#include "../common/group_data.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
// booster wrapper for backward compatible reason.
|
||||
class Booster {
|
||||
@ -439,6 +440,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
|
||||
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
|
||||
// const int nthreadmax = omp_get_max_threads();
|
||||
if (nthread <= 0) nthread=nthreadmax;
|
||||
int nthread_orig = omp_get_max_threads();
|
||||
omp_set_num_threads(nthread);
|
||||
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
@ -497,6 +499,144 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
|
||||
}
|
||||
}
|
||||
}
|
||||
// restore omp state
|
||||
omp_set_num_threads(nthread_orig);
|
||||
|
||||
mat.info.num_nonzero_ = mat.page_.data.size();
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
enum class DTType : uint8_t {
|
||||
kFloat32 = 0,
|
||||
kFloat64 = 1,
|
||||
kBool8 = 2,
|
||||
kInt32 = 3,
|
||||
kInt8 = 4,
|
||||
kInt16 = 5,
|
||||
kInt64 = 6,
|
||||
kUnknown = 7
|
||||
};
|
||||
|
||||
DTType DTGetType(std::string type_string) {
|
||||
if (type_string == "float32") {
|
||||
return DTType::kFloat32;
|
||||
} else if (type_string == "float64") {
|
||||
return DTType::kFloat64;
|
||||
} else if (type_string == "bool8") {
|
||||
return DTType::kBool8;
|
||||
} else if (type_string == "int32") {
|
||||
return DTType::kInt32;
|
||||
} else if (type_string == "int8") {
|
||||
return DTType::kInt8;
|
||||
} else if (type_string == "int16") {
|
||||
return DTType::kInt16;
|
||||
} else if (type_string == "int64") {
|
||||
return DTType::kInt64;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown data table type.";
|
||||
return DTType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
float DTGetValue(void* column, DTType dt_type, size_t ridx) {
|
||||
float missing = std::numeric_limits<float>::quiet_NaN();
|
||||
switch (dt_type) {
|
||||
case DTType::kFloat32: {
|
||||
float val = reinterpret_cast<float*>(column)[ridx];
|
||||
return std::isfinite(val) ? val : missing;
|
||||
}
|
||||
case DTType::kFloat64: {
|
||||
double val = reinterpret_cast<double*>(column)[ridx];
|
||||
return std::isfinite(val) ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kBool8: {
|
||||
bool val = reinterpret_cast<bool*>(column)[ridx];
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
case DTType::kInt32: {
|
||||
int32_t val = reinterpret_cast<int32_t*>(column)[ridx];
|
||||
return val != (-2147483647 - 1) ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kInt8: {
|
||||
int8_t val = reinterpret_cast<int8_t*>(column)[ridx];
|
||||
return val != -128 ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kInt16: {
|
||||
int16_t val = reinterpret_cast<int16_t*>(column)[ridx];
|
||||
return val != -32768 ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kInt64: {
|
||||
int64_t val = reinterpret_cast<int64_t*>(column)[ridx];
|
||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
|
||||
: missing;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Unknown data table type.";
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
|
||||
xgboost::bst_ulong nrow,
|
||||
xgboost::bst_ulong ncol, DMatrixHandle* out,
|
||||
int nthread) {
|
||||
// avoid openmp unless enough data to be worth it to avoid overhead costs
|
||||
if (nrow * ncol <= 10000 * 50) {
|
||||
nthread = 1;
|
||||
}
|
||||
|
||||
API_BEGIN();
|
||||
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
|
||||
if (nthread <= 0) nthread = nthreadmax;
|
||||
int nthread_orig = omp_get_max_threads();
|
||||
omp_set_num_threads(nthread);
|
||||
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
data::SimpleCSRSource& mat = *source;
|
||||
mat.page_.offset.resize(1 + nrow);
|
||||
mat.info.num_row_ = nrow;
|
||||
mat.info.num_col_ = ncol;
|
||||
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
// Count elements per row, column by column
|
||||
for (auto j = 0; j < ncol; ++j) {
|
||||
DTType dtype = DTGetType(feature_stypes[j]);
|
||||
#pragma omp for schedule(static)
|
||||
for (omp_ulong i = 0; i < nrow; ++i) {
|
||||
float val = DTGetValue(data[j], dtype, i);
|
||||
if (!std::isnan(val)) {
|
||||
mat.page_.offset[i + 1]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// do cumulative sum (to avoid otherwise need to copy)
|
||||
PrefixSum(&mat.page_.offset[0], mat.page_.offset.size());
|
||||
|
||||
mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());
|
||||
|
||||
// Fill data matrix (now that know size, no need for slow push_back())
|
||||
std::vector<size_t> position(nrow);
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
|
||||
DTType dtype = DTGetType(feature_stypes[j]);
|
||||
#pragma omp for schedule(static)
|
||||
for (omp_ulong i = 0; i < nrow; ++i) {
|
||||
float val = DTGetValue(data[j], dtype, i);
|
||||
if (!std::isnan(val)) {
|
||||
mat.page_.data[mat.page_.offset[i] + position[i]] = Entry(j, val);
|
||||
position[i]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// restore omp state
|
||||
omp_set_num_threads(nthread_orig);
|
||||
|
||||
mat.info.num_nonzero_ = mat.page_.data.size();
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
|
||||
@ -3,6 +3,34 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
TEST(c_api, XGDMatrixCreateFromMatDT) {
|
||||
std::vector<int> col0 = {0, -1, 3};
|
||||
std::vector<float> col1 = {-4.0f, 2.0f, 0.0f};
|
||||
const char *col0_type = "int32";
|
||||
const char *col1_type = "float32";
|
||||
std::vector<void *> data = {col0.data(), col1.data()};
|
||||
std::vector<const char *> types = {col0_type, col1_type};
|
||||
DMatrixHandle handle;
|
||||
XGDMatrixCreateFromDT(data.data(), types.data(), 3, 2, &handle,
|
||||
0);
|
||||
std::shared_ptr<xgboost::DMatrix> dmat =
|
||||
*static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
xgboost::MetaInfo &info = dmat->Info();
|
||||
ASSERT_EQ(info.num_col_, 2);
|
||||
ASSERT_EQ(info.num_row_, 3);
|
||||
ASSERT_EQ(info.num_nonzero_, 6);
|
||||
|
||||
auto iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
auto batch = iter->Value();
|
||||
ASSERT_EQ(batch[0][0].fvalue, 0.0f);
|
||||
ASSERT_EQ(batch[0][1].fvalue, -4.0f);
|
||||
ASSERT_EQ(batch[2][0].fvalue, 3.0f);
|
||||
ASSERT_EQ(batch[2][1].fvalue, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(c_api, XGDMatrixCreateFromMat_omp) {
|
||||
std::vector<int> num_rows = {100, 11374, 15000};
|
||||
for (auto row : num_rows) {
|
||||
|
||||
47
tests/python/test_dt.py
Normal file
47
tests/python/test_dt.py
Normal file
@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import unittest
|
||||
|
||||
import testing as tm
|
||||
import xgboost as xgb
|
||||
|
||||
try:
|
||||
import datatable as dt
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
tm._skip_if_no_dt()
|
||||
tm._skip_if_no_pandas()
|
||||
|
||||
|
||||
class TestDataTable(unittest.TestCase):
|
||||
|
||||
def test_dt(self):
|
||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
labels = dt.Frame([1, 2])
|
||||
dm = xgb.DMatrix(dtable, label=labels)
|
||||
assert dm.feature_names == ['a', 'b', 'c']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# overwrite feature_names
|
||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
||||
feature_names=['x', 'y', 'z'])
|
||||
assert dm.feature_names == ['x', 'y', 'z']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# incorrect dtypes
|
||||
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
self.assertRaises(ValueError, xgb.DMatrix, dtable)
|
||||
|
||||
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
||||
dtable = dt.Frame(df)
|
||||
dm = xgb.DMatrix(dtable)
|
||||
assert dm.feature_names == ['A=1', 'A=2']
|
||||
assert dm.feature_types == ['int', 'int']
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import nose
|
||||
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
|
||||
|
||||
|
||||
def _skip_if_no_sklearn():
|
||||
@ -15,6 +15,11 @@ def _skip_if_no_pandas():
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_dt():
|
||||
if not DT_INSTALLED:
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_matplotlib():
|
||||
try:
|
||||
import matplotlib.pyplot as _ # noqa
|
||||
|
||||
@ -48,6 +48,13 @@ if [ ${TASK} == "python_test" ]; then
|
||||
source activate python3
|
||||
python --version
|
||||
conda install numpy scipy pandas matplotlib nose scikit-learn
|
||||
|
||||
# Install data table from source
|
||||
wget http://releases.llvm.org/5.0.2/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
|
||||
tar xf clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
|
||||
export LLVM5=$(pwd)/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04
|
||||
python -m pip install datatable --no-binary datatable
|
||||
|
||||
python -m pip install graphviz pytest pytest-cov codecov
|
||||
python -m nose tests/python || exit -1
|
||||
py.test tests/python --cov=python-package/xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user