@@ -131,42 +131,50 @@ struct IsCatOp {
|
||||
void RemoveDuplicatedCategories(
|
||||
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
|
||||
dh::device_vector<Entry> *p_sorted_entries,
|
||||
dh::caching_device_vector<size_t> const &column_sizes_scan) {
|
||||
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
||||
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
||||
auto& column_sizes_scan = *p_column_sizes_scan;
|
||||
if (!info.feature_types.Empty() &&
|
||||
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
|
||||
IsCatOp{})) {
|
||||
auto& sorted_entries = *p_sorted_entries;
|
||||
// Removing duplicated entries in categorical features.
|
||||
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
|
||||
dh::SegmentedUnique(column_sizes_scan.data().get(),
|
||||
column_sizes_scan.data().get() +
|
||||
column_sizes_scan.size(),
|
||||
sorted_entries.begin(), sorted_entries.end(),
|
||||
new_column_scan.data().get(), sorted_entries.begin(),
|
||||
[=] __device__(Entry const &l, Entry const &r) {
|
||||
if (l.index == r.index) {
|
||||
if (IsCat(d_feature_types, l.index)) {
|
||||
return l.fvalue == r.fvalue;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
dh::SegmentedUnique(
|
||||
column_sizes_scan.data().get(),
|
||||
column_sizes_scan.data().get() + column_sizes_scan.size(),
|
||||
sorted_entries.begin(), sorted_entries.end(),
|
||||
new_column_scan.data().get(), sorted_entries.begin(),
|
||||
[=] __device__(Entry const &l, Entry const &r) {
|
||||
if (l.index == r.index) {
|
||||
if (IsCat(d_feature_types, l.index)) {
|
||||
return l.fvalue == r.fvalue;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// Renew the column scan and cut scan based on categorical data.
|
||||
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
|
||||
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
|
||||
info.num_col_ + 1);
|
||||
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
|
||||
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
|
||||
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
|
||||
dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) {
|
||||
// Launch one work item per entry of the de-duplicated column scan.
// Each item:
//   1. copies the new column offset back into the caller's column scan
//      (d_old_column_sizes_scan views *p_column_sizes_scan), and
//   2. for every entry except the last, records how many cuts the feature
//      needs in d_new_cuts_size, which is later turned into offsets by an
//      exclusive scan written to d_cuts_ptr.
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
  d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
  // The final scan entry is only an end offset; it has no feature attached,
  // and reading idx + 1 below would run past the end of the spans.
  if (idx == d_new_columns_ptr.size() - 1) {
    return;
  }
  if (IsCat(d_feature_types, idx)) {
    // Cut size is the same as number of categories in input.
    d_new_cuts_size[idx] =
        d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
  } else {
    // A non-categorical feature keeps its current number of cuts, i.e. the
    // width of its segment in the cut offsets. Subtracting an entry from
    // itself always yields 0 and would drop every cut of the feature once
    // the sizes are scanned back into offsets.
    d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
  }
});
|
||||
// Turn size into ptr.
|
||||
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
|
||||
new_cuts_size.cend(), d_cuts_ptr.data());
|
||||
}
|
||||
@@ -197,7 +205,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
||||
column_sizes_scan);
|
||||
&column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||
|
||||
|
||||
@@ -801,7 +801,8 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||
? 0
|
||||
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||
std::vector<uint32_t> cat_bits_storage(size);
|
||||
size = size == 0 ? 1 : size;
|
||||
std::vector<uint32_t> cat_bits_storage(size, 0);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
|
||||
@@ -818,7 +819,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
if (cnt == categories_nodes.size()) {
|
||||
last_cat_node = -1;
|
||||
} else {
|
||||
last_cat_node = get<Integer const>(categories_nodes[++cnt]);
|
||||
last_cat_node = get<Integer const>(categories_nodes[cnt]);
|
||||
}
|
||||
} else {
|
||||
split_categories_segments_[nidx].beg = categories.size();
|
||||
|
||||
Reference in New Issue
Block a user