parent
6bc9747df5
commit
70ce5216b5
@ -160,6 +160,7 @@ def _is_pandas_df(data):
|
||||
return False
|
||||
return isinstance(data, pd.DataFrame)
|
||||
|
||||
|
||||
def _is_modin_df(data):
|
||||
try:
|
||||
import modin.pandas as pd
|
||||
@ -188,11 +189,11 @@ def _transform_pandas_df(data, enable_categorical,
|
||||
feature_names=None, feature_types=None,
|
||||
meta=None, meta_type=None):
|
||||
from pandas import MultiIndex, Int64Index
|
||||
from pandas.api.types import is_sparse, is_categorical
|
||||
from pandas.api.types import is_sparse, is_categorical_dtype
|
||||
|
||||
data_dtypes = data.dtypes
|
||||
if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or
|
||||
(is_categorical(dtype) and enable_categorical)
|
||||
(is_categorical_dtype(dtype) and enable_categorical)
|
||||
for dtype in data_dtypes):
|
||||
bad_fields = [
|
||||
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
|
||||
@ -220,7 +221,7 @@ def _transform_pandas_df(data, enable_categorical,
|
||||
if is_sparse(dtype):
|
||||
feature_types.append(_pandas_dtype_mapper[
|
||||
dtype.subtype.name])
|
||||
elif is_categorical(dtype) and enable_categorical:
|
||||
elif is_categorical_dtype(dtype) and enable_categorical:
|
||||
feature_types.append('categorical')
|
||||
else:
|
||||
feature_types.append(_pandas_dtype_mapper[dtype.name])
|
||||
|
||||
@ -131,42 +131,50 @@ struct IsCatOp {
|
||||
void RemoveDuplicatedCategories(
|
||||
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
|
||||
dh::device_vector<Entry> *p_sorted_entries,
|
||||
dh::caching_device_vector<size_t> const &column_sizes_scan) {
|
||||
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
||||
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
||||
auto& column_sizes_scan = *p_column_sizes_scan;
|
||||
if (!info.feature_types.Empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||
IsCatOp{})) {
|
||||
auto& sorted_entries = *p_sorted_entries;
|
||||
// Removing duplicated entries in categorical features.
|
||||
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
|
||||
dh::SegmentedUnique(column_sizes_scan.data().get(),
|
||||
column_sizes_scan.data().get() +
|
||||
column_sizes_scan.size(),
|
||||
sorted_entries.begin(), sorted_entries.end(),
|
||||
new_column_scan.data().get(), sorted_entries.begin(),
|
||||
[=] __device__(Entry const &l, Entry const &r) {
|
||||
if (l.index == r.index) {
|
||||
if (IsCat(d_feature_types, l.index)) {
|
||||
return l.fvalue == r.fvalue;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
dh::SegmentedUnique(
|
||||
column_sizes_scan.data().get(),
|
||||
column_sizes_scan.data().get() + column_sizes_scan.size(),
|
||||
sorted_entries.begin(), sorted_entries.end(),
|
||||
new_column_scan.data().get(), sorted_entries.begin(),
|
||||
[=] __device__(Entry const &l, Entry const &r) {
|
||||
if (l.index == r.index) {
|
||||
if (IsCat(d_feature_types, l.index)) {
|
||||
return l.fvalue == r.fvalue;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// Renew the column scan and cut scan based on categorical data.
|
||||
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
|
||||
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
|
||||
info.num_col_ + 1);
|
||||
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
|
||||
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
|
||||
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
|
||||
dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
|
||||
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
|
||||
if (idx == d_new_columns_ptr.size() - 1) {
|
||||
return;
|
||||
}
|
||||
if (IsCat(d_feature_types, idx)) {
|
||||
// Cut size is the same as number of categories in input.
|
||||
d_new_cuts_size[idx] =
|
||||
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
|
||||
} else {
|
||||
d_new_cuts_size[idx] = d_cuts_ptr[idx] - d_cuts_ptr[idx];
|
||||
}
|
||||
});
|
||||
// Turn size into ptr.
|
||||
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
|
||||
new_cuts_size.cend(), d_cuts_ptr.data());
|
||||
}
|
||||
@ -197,7 +205,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
||||
column_sizes_scan);
|
||||
&column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||
|
||||
|
||||
@ -801,7 +801,8 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||
? 0
|
||||
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||
std::vector<uint32_t> cat_bits_storage(size);
|
||||
size = size == 0 ? 1 : size;
|
||||
std::vector<uint32_t> cat_bits_storage(size, 0);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
|
||||
@ -818,7 +819,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
if (cnt == categories_nodes.size()) {
|
||||
last_cat_node = -1;
|
||||
} else {
|
||||
last_cat_node = get<Integer const>(categories_nodes[++cnt]);
|
||||
last_cat_node = get<Integer const>(categories_nodes[cnt]);
|
||||
}
|
||||
} else {
|
||||
split_categories_segments_[nidx].beg = categories.size();
|
||||
|
||||
@ -41,6 +41,52 @@ class TestGPUUpdaters:
|
||||
note(result)
|
||||
assert tm.non_increasing(result['train'][dataset.metric])
|
||||
|
||||
def run_categorical_basic(self, cat, onehot, label, rounds):
|
||||
by_etl_results = {}
|
||||
by_builtin_results = {}
|
||||
|
||||
parameters = {'tree_method': 'gpu_hist',
|
||||
'predictor': 'gpu_predictor',
|
||||
'enable_experimental_json_serialization': True}
|
||||
|
||||
m = xgb.DMatrix(onehot, label, enable_categorical=True)
|
||||
xgb.train(parameters, m,
|
||||
num_boost_round=rounds,
|
||||
evals=[(m, 'Train')], evals_result=by_etl_results)
|
||||
|
||||
m = xgb.DMatrix(cat, label, enable_categorical=True)
|
||||
xgb.train(parameters, m,
|
||||
num_boost_round=rounds,
|
||||
evals=[(m, 'Train')], evals_result=by_builtin_results)
|
||||
np.testing.assert_allclose(
|
||||
np.array(by_etl_results['Train']['rmse']),
|
||||
np.array(by_builtin_results['Train']['rmse']),
|
||||
rtol=1e-3)
|
||||
assert tm.non_increasing(by_builtin_results['Train']['rmse'])
|
||||
|
||||
@given(strategies.integers(10, 400), strategies.integers(5, 10),
|
||||
strategies.integers(1, 5), strategies.integers(4, 8))
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
import pandas as pd
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
pd_dict = {}
|
||||
for i in range(cols):
|
||||
c = rng.randint(low=0, high=cats+1, size=rows)
|
||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||
|
||||
df = pd.DataFrame(pd_dict)
|
||||
label = df.iloc[:, 0]
|
||||
for i in range(0, cols-1):
|
||||
label += df.iloc[:, i]
|
||||
label += 1
|
||||
df = df.astype('category')
|
||||
x = pd.get_dummies(df)
|
||||
|
||||
self.run_categorical_basic(df, x, label, rounds)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@given(parameter_strategy, strategies.integers(1, 20),
|
||||
tm.dataset_strategy)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user