Remove unused weight from buffer for cat features. (#9341)

This commit is contained in:
Jiaming Yuan
2023-07-04 01:07:09 +08:00
committed by GitHub
parent 6155394a06
commit d0916849a6
3 changed files with 142 additions and 55 deletions

View File

@@ -127,55 +127,76 @@ void SortByWeight(dh::device_vector<float>* weights,
});
}
void RemoveDuplicatedCategories(
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
dh::device_vector<Entry> *p_sorted_entries,
dh::caching_device_vector<size_t> *p_column_sizes_scan) {
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
info.feature_types.SetDevice(device);
auto d_feature_types = info.feature_types.ConstDeviceSpan();
CHECK(!d_feature_types.empty());
auto &column_sizes_scan = *p_column_sizes_scan;
auto &sorted_entries = *p_sorted_entries;
auto& column_sizes_scan = *p_column_sizes_scan;
auto& sorted_entries = *p_sorted_entries;
// Removing duplicated entries in categorical features.
// We don't need to accumulate weight for duplicated entries as there's no weighted
// sketching for categorical features, the categories are the cut values.
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
dh::SegmentedUnique(column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const &l, Entry const &r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
std::size_t n_uniques{0};
if (p_sorted_weights) {
using Pair = thrust::tuple<Entry, float>;
auto d_sorted_entries = dh::ToSpan(sorted_entries);
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
[=] __device__(Pair const& l, Pair const& r) {
Entry const& le = thrust::get<0>(l);
Entry const& re = thrust::get<0>(r);
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
return le.fvalue == re.fvalue;
}
return false;
});
p_sorted_weights->resize(n_uniques);
} else {
n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
}
sorted_entries.resize(n_uniques);
// Renew the column scan and cut scan based on categorical data.
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
info.num_col_ + 1);
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(info.num_col_ + 1);
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(
new_column_scan.size(),
[=, d_new_cuts_size = dh::ToSpan(new_cuts_size),
d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan),
d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) {
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
if (idx == d_new_columns_ptr.size() - 1) {
return;
}
if (IsCat(d_feature_types, idx)) {
// Cut size is the same as number of categories in input.
d_new_cuts_size[idx] =
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
} else {
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
}
});
dh::LaunchN(new_column_scan.size(),
[=, d_new_cuts_size = dh::ToSpan(new_cuts_size),
d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan),
d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) {
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
if (idx == d_new_columns_ptr.size() - 1) {
return;
}
if (IsCat(d_feature_types, idx)) {
// Cut size is the same as number of categories in input.
d_new_cuts_size[idx] = d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
} else {
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
}
});
// Turn size into ptr.
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
new_cuts_size.cend(), d_cuts_ptr.data());
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), new_cuts_size.cend(),
d_cuts_ptr.data());
}
} // namespace detail
@@ -209,8 +230,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
&sorted_entries, &column_sizes_scan);
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
&column_sizes_scan);
}
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
@@ -276,8 +297,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
&sorted_entries, &column_sizes_scan);
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
&column_sizes_scan);
}
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();

View File

@@ -240,10 +240,10 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran
void SortByWeight(dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries);
void RemoveDuplicatedCategories(
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
dh::device_vector<Entry> *p_sorted_entries,
dh::caching_device_vector<size_t> *p_column_sizes_scan);
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan);
} // namespace detail
// Compute sketch on DMatrix.
@@ -275,8 +275,8 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
&sorted_entries, &column_sizes_scan);
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
&column_sizes_scan);
}
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
@@ -354,8 +354,8 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
&sorted_entries, &column_sizes_scan);
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
&column_sizes_scan);
}
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();