Take datatable as row major input. (#8472)
* Take datatable as row major input. Try to avoid a transform with dense table.
This commit is contained in:
parent
284dcf8d22
commit
e07245f110
@ -44,8 +44,9 @@ _matrix_meta = {"base_margin", "label"}
|
|||||||
def _warn_unused_missing(data: DataType, missing: Optional[FloatCompatible]) -> None:
|
def _warn_unused_missing(data: DataType, missing: Optional[FloatCompatible]) -> None:
|
||||||
if (missing is not None) and (not np.isnan(missing)):
|
if (missing is not None) and (not np.isnan(missing)):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'`missing` is not used for current input data type:' +
|
"`missing` is not used for current input data type:" + str(type(data)),
|
||||||
str(type(data)), UserWarning)
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_complex(data: DataType) -> None:
|
def _check_complex(data: DataType) -> None:
|
||||||
@ -459,12 +460,9 @@ def _from_pandas_series(
|
|||||||
|
|
||||||
|
|
||||||
def _is_dt_df(data: DataType) -> bool:
|
def _is_dt_df(data: DataType) -> bool:
|
||||||
return lazy_isinstance(data, 'datatable', 'Frame') or \
|
return lazy_isinstance(data, "datatable", "Frame") or lazy_isinstance(
|
||||||
lazy_isinstance(data, 'datatable', 'DataTable')
|
data, "datatable", "DataTable"
|
||||||
|
)
|
||||||
|
|
||||||
_dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
|
|
||||||
_dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
|
||||||
|
|
||||||
|
|
||||||
def _transform_dt_df(
|
def _transform_dt_df(
|
||||||
@ -475,8 +473,10 @@ def _transform_dt_df(
|
|||||||
meta_type: Optional[NumpyDType] = None,
|
meta_type: Optional[NumpyDType] = None,
|
||||||
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
|
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
|
||||||
"""Validate feature names and types if data table"""
|
"""Validate feature names and types if data table"""
|
||||||
|
_dt_type_mapper = {"bool": "bool", "int": "int", "real": "float"}
|
||||||
|
_dt_type_mapper2 = {"bool": "i", "int": "int", "real": "float"}
|
||||||
if meta and data.shape[1] > 1:
|
if meta and data.shape[1] > 1:
|
||||||
raise ValueError('DataTable for meta info cannot have multiple columns')
|
raise ValueError("DataTable for meta info cannot have multiple columns")
|
||||||
if meta:
|
if meta:
|
||||||
meta_type = "float" if meta_type is None else meta_type
|
meta_type = "float" if meta_type is None else meta_type
|
||||||
# below requires new dt version
|
# below requires new dt version
|
||||||
@ -485,23 +485,23 @@ def _transform_dt_df(
|
|||||||
return data, None, None
|
return data, None, None
|
||||||
|
|
||||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||||
bad_fields = [data.names[i]
|
bad_fields = [
|
||||||
for i, type_name in enumerate(data_types_names)
|
data.names[i]
|
||||||
if type_name not in _dt_type_mapper]
|
for i, type_name in enumerate(data_types_names)
|
||||||
|
if type_name not in _dt_type_mapper
|
||||||
|
]
|
||||||
if bad_fields:
|
if bad_fields:
|
||||||
msg = """DataFrame.types for data must be int, float or bool.
|
msg = """DataFrame.types for data must be int, float or bool.
|
||||||
Did not expect the data types in fields """
|
Did not expect the data types in fields """
|
||||||
raise ValueError(msg + ', '.join(bad_fields))
|
raise ValueError(msg + ", ".join(bad_fields))
|
||||||
|
|
||||||
if feature_names is None and meta is None:
|
if feature_names is None and meta is None:
|
||||||
feature_names = data.names
|
feature_names = data.names
|
||||||
|
|
||||||
# always return stypes for dt ingestion
|
# always return stypes for dt ingestion
|
||||||
if feature_types is not None:
|
if feature_types is not None:
|
||||||
raise ValueError(
|
raise ValueError("DataTable has own feature types, cannot pass them in.")
|
||||||
'DataTable has own feature types, cannot pass them in.')
|
feature_types = np.vectorize(_dt_type_mapper2.get)(data_types_names).tolist()
|
||||||
feature_types = np.vectorize(_dt_type_mapper2.get)(
|
|
||||||
data_types_names).tolist()
|
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
@ -517,7 +517,8 @@ def _from_dt_df(
|
|||||||
if enable_categorical:
|
if enable_categorical:
|
||||||
raise ValueError("categorical data in datatable is not supported yet.")
|
raise ValueError("categorical data in datatable is not supported yet.")
|
||||||
data, feature_names, feature_types = _transform_dt_df(
|
data, feature_names, feature_types = _transform_dt_df(
|
||||||
data, feature_names, feature_types, None, None)
|
data, feature_names, feature_types, None, None
|
||||||
|
)
|
||||||
|
|
||||||
ptrs = (ctypes.c_void_p * data.ncols)()
|
ptrs = (ctypes.c_void_p * data.ncols)()
|
||||||
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
||||||
@ -531,6 +532,7 @@ def _from_dt_df(
|
|||||||
from datatable.internal import (
|
from datatable.internal import (
|
||||||
frame_column_data_r, # pylint: disable=no-name-in-module
|
frame_column_data_r, # pylint: disable=no-name-in-module
|
||||||
)
|
)
|
||||||
|
|
||||||
for icol in range(data.ncols):
|
for icol in range(data.ncols):
|
||||||
ptrs[icol] = frame_column_data_r(data, icol)
|
ptrs[icol] = frame_column_data_r(data, icol)
|
||||||
|
|
||||||
@ -538,16 +540,21 @@ def _from_dt_df(
|
|||||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||||
for icol in range(data.ncols):
|
for icol in range(data.ncols):
|
||||||
feature_type_strings[icol] = ctypes.c_char_p(
|
feature_type_strings[icol] = ctypes.c_char_p(
|
||||||
data.stypes[icol].name.encode('utf-8'))
|
data.stypes[icol].name.encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
_warn_unused_missing(data, missing)
|
_warn_unused_missing(data, missing)
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
_check_call(
|
||||||
ptrs, feature_type_strings,
|
_LIB.XGDMatrixCreateFromDT(
|
||||||
c_bst_ulong(data.shape[0]),
|
ptrs,
|
||||||
c_bst_ulong(data.shape[1]),
|
feature_type_strings,
|
||||||
ctypes.byref(handle),
|
c_bst_ulong(data.shape[0]),
|
||||||
ctypes.c_int(nthread)))
|
c_bst_ulong(data.shape[1]),
|
||||||
|
ctypes.byref(handle),
|
||||||
|
ctypes.c_int(nthread),
|
||||||
|
)
|
||||||
|
)
|
||||||
return handle, feature_names, feature_types
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -473,16 +473,7 @@ class CSCAdapter : public detail::SingleBatchDataIter<CSCAdapterBatch> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
class DataTableAdapterBatch : public detail::NoMetaInfo {
|
class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||||
public:
|
enum class DTType : std::uint8_t {
|
||||||
DataTableAdapterBatch(void** data, const char** feature_stypes,
|
|
||||||
size_t num_rows, size_t num_features)
|
|
||||||
: data_(data),
|
|
||||||
feature_stypes_(feature_stypes),
|
|
||||||
num_features_(num_features),
|
|
||||||
num_rows_(num_rows) {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
enum class DTType : uint8_t {
|
|
||||||
kFloat32 = 0,
|
kFloat32 = 0,
|
||||||
kFloat64 = 1,
|
kFloat64 = 1,
|
||||||
kBool8 = 2,
|
kBool8 = 2,
|
||||||
@ -493,7 +484,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
|||||||
kUnknown = 7
|
kUnknown = 7
|
||||||
};
|
};
|
||||||
|
|
||||||
DTType DTGetType(std::string type_string) const {
|
static DTType DTGetType(std::string type_string) {
|
||||||
if (type_string == "float32") {
|
if (type_string == "float32") {
|
||||||
return DTType::kFloat32;
|
return DTType::kFloat32;
|
||||||
} else if (type_string == "float64") {
|
} else if (type_string == "float64") {
|
||||||
@ -514,8 +505,23 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
DataTableAdapterBatch(void const* const* const data, char const* const* feature_stypes,
|
||||||
|
std::size_t num_rows, std::size_t num_features)
|
||||||
|
: data_(data), num_rows_(num_rows) {
|
||||||
|
CHECK(feature_types_.empty());
|
||||||
|
std::transform(feature_stypes, feature_stypes + num_features,
|
||||||
|
std::back_inserter(feature_types_),
|
||||||
|
[](char const* stype) { return DTGetType(stype); });
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
class Line {
|
class Line {
|
||||||
float DTGetValue(const void* column, DTType dt_type, size_t ridx) const {
|
std::size_t row_idx_;
|
||||||
|
void const* const* const data_;
|
||||||
|
std::vector<DTType> const& feature_types_;
|
||||||
|
|
||||||
|
float DTGetValue(void const* column, DTType dt_type, std::size_t ridx) const {
|
||||||
float missing = std::numeric_limits<float>::quiet_NaN();
|
float missing = std::numeric_limits<float>::quiet_NaN();
|
||||||
switch (dt_type) {
|
switch (dt_type) {
|
||||||
case DTType::kFloat32: {
|
case DTType::kFloat32: {
|
||||||
@ -544,8 +550,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
|||||||
}
|
}
|
||||||
case DTType::kInt64: {
|
case DTType::kInt64: {
|
||||||
int64_t val = reinterpret_cast<const int64_t*>(column)[ridx];
|
int64_t val = reinterpret_cast<const int64_t*>(column)[ridx];
|
||||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
|
return val != -9223372036854775807 - 1 ? static_cast<float>(val) : missing;
|
||||||
: missing;
|
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
LOG(FATAL) << "Unknown data table type.";
|
LOG(FATAL) << "Unknown data table type.";
|
||||||
@ -555,51 +560,41 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Line(DTType type, size_t size, size_t column_idx, const void* column)
|
Line(std::size_t ridx, void const* const* const data, std::vector<DTType> const& ft)
|
||||||
: type_(type), size_(size), column_idx_(column_idx), column_(column) {}
|
: row_idx_{ridx}, data_{data}, feature_types_{ft} {}
|
||||||
|
std::size_t Size() const { return feature_types_.size(); }
|
||||||
size_t Size() const { return size_; }
|
COOTuple GetElement(std::size_t idx) const {
|
||||||
COOTuple GetElement(size_t idx) const {
|
return COOTuple{row_idx_, idx, DTGetValue(data_[idx], feature_types_[idx], row_idx_)};
|
||||||
return COOTuple{idx, column_idx_, DTGetValue(column_, type_, idx)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
DTType type_;
|
|
||||||
size_t size_;
|
|
||||||
size_t column_idx_;
|
|
||||||
const void* column_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
size_t Size() const { return num_features_; }
|
size_t Size() const { return num_rows_; }
|
||||||
const Line GetLine(size_t idx) const {
|
const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
|
||||||
return Line(DTGetType(feature_stypes_[idx]), num_rows_, idx, data_[idx]);
|
static constexpr bool kIsRowMajor = true;
|
||||||
}
|
|
||||||
static constexpr bool kIsRowMajor = false;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void** data_;
|
void const* const* const data_;
|
||||||
const char** feature_stypes_;
|
|
||||||
size_t num_features_;
|
std::vector<DTType> feature_types_;
|
||||||
size_t num_rows_;
|
std::size_t num_rows_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class DataTableAdapter
|
class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatch> {
|
||||||
: public detail::SingleBatchDataIter<DataTableAdapterBatch> {
|
|
||||||
public:
|
public:
|
||||||
DataTableAdapter(void** data, const char** feature_stypes, size_t num_rows,
|
DataTableAdapter(void** data, const char** feature_stypes, std::size_t num_rows,
|
||||||
size_t num_features)
|
std::size_t num_features)
|
||||||
: batch_(data, feature_stypes, num_rows, num_features),
|
: batch_(data, feature_stypes, num_rows, num_features),
|
||||||
num_rows_(num_rows),
|
num_rows_(num_rows),
|
||||||
num_columns_(num_features) {}
|
num_columns_(num_features) {}
|
||||||
const DataTableAdapterBatch& Value() const override { return batch_; }
|
const DataTableAdapterBatch& Value() const override { return batch_; }
|
||||||
size_t NumRows() const { return num_rows_; }
|
std::size_t NumRows() const { return num_rows_; }
|
||||||
size_t NumColumns() const { return num_columns_; }
|
std::size_t NumColumns() const { return num_columns_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DataTableAdapterBatch batch_;
|
DataTableAdapterBatch batch_;
|
||||||
size_t num_rows_;
|
std::size_t num_rows_;
|
||||||
size_t num_columns_;
|
std::size_t num_columns_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FileAdapterBatch {
|
class FileAdapterBatch {
|
||||||
|
|||||||
@ -144,6 +144,7 @@ def main(args: argparse.Namespace) -> None:
|
|||||||
# tests
|
# tests
|
||||||
"tests/python/test_config.py",
|
"tests/python/test_config.py",
|
||||||
"tests/python/test_data_iterator.py",
|
"tests/python/test_data_iterator.py",
|
||||||
|
"tests/python/test_dt.py",
|
||||||
"tests/python/test_quantile_dmatrix.py",
|
"tests/python/test_quantile_dmatrix.py",
|
||||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||||
"tests/test_distributed/test_with_spark/",
|
"tests/test_distributed/test_with_spark/",
|
||||||
@ -194,6 +195,7 @@ def main(args: argparse.Namespace) -> None:
|
|||||||
"demo/guide-python/external_memory.py",
|
"demo/guide-python/external_memory.py",
|
||||||
"demo/guide-python/cat_in_the_dat.py",
|
"demo/guide-python/cat_in_the_dat.py",
|
||||||
# tests
|
# tests
|
||||||
|
"tests/python/test_dt.py",
|
||||||
"tests/python/test_data_iterator.py",
|
"tests/python/test_data_iterator.py",
|
||||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||||
"tests/test_distributed/test_with_spark/test_data.py",
|
"tests/test_distributed/test_with_spark/test_data.py",
|
||||||
|
|||||||
@ -2,52 +2,40 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
|
||||||
|
|
||||||
try:
|
dt = pytest.importorskip("datatable")
|
||||||
import datatable as dt
|
pd = pytest.importorskip("pandas")
|
||||||
import pandas as pd
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
|
|
||||||
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataTable:
|
class TestDataTable:
|
||||||
|
def test_dt(self) -> None:
|
||||||
def test_dt(self):
|
df = pd.DataFrame([[1, 2.0, True], [2, 3.0, False]], columns=["a", "b", "c"])
|
||||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
|
|
||||||
columns=['a', 'b', 'c'])
|
|
||||||
dtable = dt.Frame(df)
|
dtable = dt.Frame(df)
|
||||||
labels = dt.Frame([1, 2])
|
labels = dt.Frame([1, 2])
|
||||||
dm = xgb.DMatrix(dtable, label=labels)
|
dm = xgb.DMatrix(dtable, label=labels)
|
||||||
assert dm.feature_names == ['a', 'b', 'c']
|
assert dm.feature_names == ["a", "b", "c"]
|
||||||
assert dm.feature_types == ['int', 'float', 'i']
|
assert dm.feature_types == ["int", "float", "i"]
|
||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 3
|
assert dm.num_col() == 3
|
||||||
|
|
||||||
np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
|
np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
|
||||||
|
|
||||||
# overwrite feature_names
|
# overwrite feature_names
|
||||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]), feature_names=["x", "y", "z"])
|
||||||
feature_names=['x', 'y', 'z'])
|
assert dm.feature_names == ["x", "y", "z"]
|
||||||
assert dm.feature_names == ['x', 'y', 'z']
|
|
||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 3
|
assert dm.num_col() == 3
|
||||||
|
|
||||||
# incorrect dtypes
|
# incorrect dtypes
|
||||||
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
|
df = pd.DataFrame([[1, 2.0, "x"], [2, 3.0, "y"]], columns=["a", "b", "c"])
|
||||||
columns=['a', 'b', 'c'])
|
|
||||||
dtable = dt.Frame(df)
|
dtable = dt.Frame(df)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
xgb.DMatrix(dtable)
|
xgb.DMatrix(dtable)
|
||||||
|
|
||||||
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
df = pd.DataFrame({"A=1": [1, 2, 3], "A=2": [4, 5, 6]})
|
||||||
dtable = dt.Frame(df)
|
dtable = dt.Frame(df)
|
||||||
dm = xgb.DMatrix(dtable)
|
dm = xgb.DMatrix(dtable)
|
||||||
assert dm.feature_names == ['A=1', 'A=2']
|
assert dm.feature_names == ["A=1", "A=2"]
|
||||||
assert dm.feature_types == ['int', 'int']
|
assert dm.feature_types == ["int", "int"]
|
||||||
assert dm.num_row() == 3
|
assert dm.num_row() == 3
|
||||||
assert dm.num_col() == 2
|
assert dm.num_col() == 2
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user