Take datatable as row major input. (#8472)
* Take datatable as row major input. Try to avoid a transform with dense table.
This commit is contained in:
parent
284dcf8d22
commit
e07245f110
@ -44,8 +44,9 @@ _matrix_meta = {"base_margin", "label"}
|
||||
def _warn_unused_missing(data: DataType, missing: Optional[FloatCompatible]) -> None:
|
||||
if (missing is not None) and (not np.isnan(missing)):
|
||||
warnings.warn(
|
||||
'`missing` is not used for current input data type:' +
|
||||
str(type(data)), UserWarning)
|
||||
"`missing` is not used for current input data type:" + str(type(data)),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _check_complex(data: DataType) -> None:
|
||||
@ -459,12 +460,9 @@ def _from_pandas_series(
|
||||
|
||||
|
||||
def _is_dt_df(data: DataType) -> bool:
|
||||
return lazy_isinstance(data, 'datatable', 'Frame') or \
|
||||
lazy_isinstance(data, 'datatable', 'DataTable')
|
||||
|
||||
|
||||
_dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
|
||||
_dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
||||
return lazy_isinstance(data, "datatable", "Frame") or lazy_isinstance(
|
||||
data, "datatable", "DataTable"
|
||||
)
|
||||
|
||||
|
||||
def _transform_dt_df(
|
||||
@ -475,8 +473,10 @@ def _transform_dt_df(
|
||||
meta_type: Optional[NumpyDType] = None,
|
||||
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
|
||||
"""Validate feature names and types if data table"""
|
||||
_dt_type_mapper = {"bool": "bool", "int": "int", "real": "float"}
|
||||
_dt_type_mapper2 = {"bool": "i", "int": "int", "real": "float"}
|
||||
if meta and data.shape[1] > 1:
|
||||
raise ValueError('DataTable for meta info cannot have multiple columns')
|
||||
raise ValueError("DataTable for meta info cannot have multiple columns")
|
||||
if meta:
|
||||
meta_type = "float" if meta_type is None else meta_type
|
||||
# below requires new dt version
|
||||
@ -485,23 +485,23 @@ def _transform_dt_df(
|
||||
return data, None, None
|
||||
|
||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||
bad_fields = [data.names[i]
|
||||
bad_fields = [
|
||||
data.names[i]
|
||||
for i, type_name in enumerate(data_types_names)
|
||||
if type_name not in _dt_type_mapper]
|
||||
if type_name not in _dt_type_mapper
|
||||
]
|
||||
if bad_fields:
|
||||
msg = """DataFrame.types for data must be int, float or bool.
|
||||
Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
raise ValueError(msg + ", ".join(bad_fields))
|
||||
|
||||
if feature_names is None and meta is None:
|
||||
feature_names = data.names
|
||||
|
||||
# always return stypes for dt ingestion
|
||||
if feature_types is not None:
|
||||
raise ValueError(
|
||||
'DataTable has own feature types, cannot pass them in.')
|
||||
feature_types = np.vectorize(_dt_type_mapper2.get)(
|
||||
data_types_names).tolist()
|
||||
raise ValueError("DataTable has own feature types, cannot pass them in.")
|
||||
feature_types = np.vectorize(_dt_type_mapper2.get)(data_types_names).tolist()
|
||||
|
||||
return data, feature_names, feature_types
|
||||
|
||||
@ -517,7 +517,8 @@ def _from_dt_df(
|
||||
if enable_categorical:
|
||||
raise ValueError("categorical data in datatable is not supported yet.")
|
||||
data, feature_names, feature_types = _transform_dt_df(
|
||||
data, feature_names, feature_types, None, None)
|
||||
data, feature_names, feature_types, None, None
|
||||
)
|
||||
|
||||
ptrs = (ctypes.c_void_p * data.ncols)()
|
||||
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
||||
@ -531,6 +532,7 @@ def _from_dt_df(
|
||||
from datatable.internal import (
|
||||
frame_column_data_r, # pylint: disable=no-name-in-module
|
||||
)
|
||||
|
||||
for icol in range(data.ncols):
|
||||
ptrs[icol] = frame_column_data_r(data, icol)
|
||||
|
||||
@ -538,16 +540,21 @@ def _from_dt_df(
|
||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||
for icol in range(data.ncols):
|
||||
feature_type_strings[icol] = ctypes.c_char_p(
|
||||
data.stypes[icol].name.encode('utf-8'))
|
||||
data.stypes[icol].name.encode("utf-8")
|
||||
)
|
||||
|
||||
_warn_unused_missing(data, missing)
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
||||
ptrs, feature_type_strings,
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromDT(
|
||||
ptrs,
|
||||
feature_type_strings,
|
||||
c_bst_ulong(data.shape[0]),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
ctypes.byref(handle),
|
||||
ctypes.c_int(nthread)))
|
||||
ctypes.c_int(nthread),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
|
||||
@ -473,16 +473,7 @@ class CSCAdapter : public detail::SingleBatchDataIter<CSCAdapterBatch> {
|
||||
};
|
||||
|
||||
class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
DataTableAdapterBatch(void** data, const char** feature_stypes,
|
||||
size_t num_rows, size_t num_features)
|
||||
: data_(data),
|
||||
feature_stypes_(feature_stypes),
|
||||
num_features_(num_features),
|
||||
num_rows_(num_rows) {}
|
||||
|
||||
private:
|
||||
enum class DTType : uint8_t {
|
||||
enum class DTType : std::uint8_t {
|
||||
kFloat32 = 0,
|
||||
kFloat64 = 1,
|
||||
kBool8 = 2,
|
||||
@ -493,7 +484,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
kUnknown = 7
|
||||
};
|
||||
|
||||
DTType DTGetType(std::string type_string) const {
|
||||
static DTType DTGetType(std::string type_string) {
|
||||
if (type_string == "float32") {
|
||||
return DTType::kFloat32;
|
||||
} else if (type_string == "float64") {
|
||||
@ -514,8 +505,23 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
DataTableAdapterBatch(void const* const* const data, char const* const* feature_stypes,
|
||||
std::size_t num_rows, std::size_t num_features)
|
||||
: data_(data), num_rows_(num_rows) {
|
||||
CHECK(feature_types_.empty());
|
||||
std::transform(feature_stypes, feature_stypes + num_features,
|
||||
std::back_inserter(feature_types_),
|
||||
[](char const* stype) { return DTGetType(stype); });
|
||||
}
|
||||
|
||||
private:
|
||||
class Line {
|
||||
float DTGetValue(const void* column, DTType dt_type, size_t ridx) const {
|
||||
std::size_t row_idx_;
|
||||
void const* const* const data_;
|
||||
std::vector<DTType> const& feature_types_;
|
||||
|
||||
float DTGetValue(void const* column, DTType dt_type, std::size_t ridx) const {
|
||||
float missing = std::numeric_limits<float>::quiet_NaN();
|
||||
switch (dt_type) {
|
||||
case DTType::kFloat32: {
|
||||
@ -544,8 +550,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
case DTType::kInt64: {
|
||||
int64_t val = reinterpret_cast<const int64_t*>(column)[ridx];
|
||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
|
||||
: missing;
|
||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val) : missing;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Unknown data table type.";
|
||||
@ -555,51 +560,41 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
|
||||
public:
|
||||
Line(DTType type, size_t size, size_t column_idx, const void* column)
|
||||
: type_(type), size_(size), column_idx_(column_idx), column_(column) {}
|
||||
|
||||
size_t Size() const { return size_; }
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return COOTuple{idx, column_idx_, DTGetValue(column_, type_, idx)};
|
||||
Line(std::size_t ridx, void const* const* const data, std::vector<DTType> const& ft)
|
||||
: row_idx_{ridx}, data_{data}, feature_types_{ft} {}
|
||||
std::size_t Size() const { return feature_types_.size(); }
|
||||
COOTuple GetElement(std::size_t idx) const {
|
||||
return COOTuple{row_idx_, idx, DTGetValue(data_[idx], feature_types_[idx], row_idx_)};
|
||||
}
|
||||
|
||||
private:
|
||||
DTType type_;
|
||||
size_t size_;
|
||||
size_t column_idx_;
|
||||
const void* column_;
|
||||
};
|
||||
|
||||
public:
|
||||
size_t Size() const { return num_features_; }
|
||||
const Line GetLine(size_t idx) const {
|
||||
return Line(DTGetType(feature_stypes_[idx]), num_rows_, idx, data_[idx]);
|
||||
}
|
||||
static constexpr bool kIsRowMajor = false;
|
||||
size_t Size() const { return num_rows_; }
|
||||
const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
|
||||
private:
|
||||
void** data_;
|
||||
const char** feature_stypes_;
|
||||
size_t num_features_;
|
||||
size_t num_rows_;
|
||||
void const* const* const data_;
|
||||
|
||||
std::vector<DTType> feature_types_;
|
||||
std::size_t num_rows_;
|
||||
};
|
||||
|
||||
class DataTableAdapter
|
||||
: public detail::SingleBatchDataIter<DataTableAdapterBatch> {
|
||||
class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatch> {
|
||||
public:
|
||||
DataTableAdapter(void** data, const char** feature_stypes, size_t num_rows,
|
||||
size_t num_features)
|
||||
DataTableAdapter(void** data, const char** feature_stypes, std::size_t num_rows,
|
||||
std::size_t num_features)
|
||||
: batch_(data, feature_stypes, num_rows, num_features),
|
||||
num_rows_(num_rows),
|
||||
num_columns_(num_features) {}
|
||||
const DataTableAdapterBatch& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return num_rows_; }
|
||||
size_t NumColumns() const { return num_columns_; }
|
||||
std::size_t NumRows() const { return num_rows_; }
|
||||
std::size_t NumColumns() const { return num_columns_; }
|
||||
|
||||
private:
|
||||
DataTableAdapterBatch batch_;
|
||||
size_t num_rows_;
|
||||
size_t num_columns_;
|
||||
std::size_t num_rows_;
|
||||
std::size_t num_columns_;
|
||||
};
|
||||
|
||||
class FileAdapterBatch {
|
||||
|
||||
@ -144,6 +144,7 @@ def main(args: argparse.Namespace) -> None:
|
||||
# tests
|
||||
"tests/python/test_config.py",
|
||||
"tests/python/test_data_iterator.py",
|
||||
"tests/python/test_dt.py",
|
||||
"tests/python/test_quantile_dmatrix.py",
|
||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||
"tests/test_distributed/test_with_spark/",
|
||||
@ -194,6 +195,7 @@ def main(args: argparse.Namespace) -> None:
|
||||
"demo/guide-python/external_memory.py",
|
||||
"demo/guide-python/cat_in_the_dat.py",
|
||||
# tests
|
||||
"tests/python/test_dt.py",
|
||||
"tests/python/test_data_iterator.py",
|
||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||
"tests/test_distributed/test_with_spark/test_data.py",
|
||||
|
||||
@ -2,52 +2,40 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
|
||||
try:
|
||||
import datatable as dt
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
|
||||
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
|
||||
dt = pytest.importorskip("datatable")
|
||||
pd = pytest.importorskip("pandas")
|
||||
|
||||
|
||||
class TestDataTable:
|
||||
|
||||
def test_dt(self):
|
||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
|
||||
columns=['a', 'b', 'c'])
|
||||
def test_dt(self) -> None:
|
||||
df = pd.DataFrame([[1, 2.0, True], [2, 3.0, False]], columns=["a", "b", "c"])
|
||||
dtable = dt.Frame(df)
|
||||
labels = dt.Frame([1, 2])
|
||||
dm = xgb.DMatrix(dtable, label=labels)
|
||||
assert dm.feature_names == ['a', 'b', 'c']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.feature_names == ["a", "b", "c"]
|
||||
assert dm.feature_types == ["int", "float", "i"]
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
|
||||
|
||||
# overwrite feature_names
|
||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
||||
feature_names=['x', 'y', 'z'])
|
||||
assert dm.feature_names == ['x', 'y', 'z']
|
||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]), feature_names=["x", "y", "z"])
|
||||
assert dm.feature_names == ["x", "y", "z"]
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# incorrect dtypes
|
||||
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
|
||||
columns=['a', 'b', 'c'])
|
||||
df = pd.DataFrame([[1, 2.0, "x"], [2, 3.0, "y"]], columns=["a", "b", "c"])
|
||||
dtable = dt.Frame(df)
|
||||
with pytest.raises(ValueError):
|
||||
xgb.DMatrix(dtable)
|
||||
|
||||
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
||||
df = pd.DataFrame({"A=1": [1, 2, 3], "A=2": [4, 5, 6]})
|
||||
dtable = dt.Frame(df)
|
||||
dm = xgb.DMatrix(dtable)
|
||||
assert dm.feature_names == ['A=1', 'A=2']
|
||||
assert dm.feature_types == ['int', 'int']
|
||||
assert dm.feature_names == ["A=1", "A=2"]
|
||||
assert dm.feature_types == ["int", "int"]
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user