Remove unused weight from buffer for cat features. (#9341)
This commit is contained in:
parent
6155394a06
commit
d0916849a6
@ -127,22 +127,44 @@ void SortByWeight(dh::device_vector<float>* weights,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveDuplicatedCategories(
|
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
|
||||||
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
|
dh::device_vector<Entry>* p_sorted_entries,
|
||||||
dh::device_vector<Entry> *p_sorted_entries,
|
dh::device_vector<float>* p_sorted_weights,
|
||||||
dh::caching_device_vector<size_t> *p_column_sizes_scan) {
|
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
||||||
info.feature_types.SetDevice(device);
|
info.feature_types.SetDevice(device);
|
||||||
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
||||||
CHECK(!d_feature_types.empty());
|
CHECK(!d_feature_types.empty());
|
||||||
auto &column_sizes_scan = *p_column_sizes_scan;
|
auto& column_sizes_scan = *p_column_sizes_scan;
|
||||||
auto &sorted_entries = *p_sorted_entries;
|
auto& sorted_entries = *p_sorted_entries;
|
||||||
// Removing duplicated entries in categorical features.
|
// Removing duplicated entries in categorical features.
|
||||||
|
|
||||||
|
// We don't need to accumulate weight for duplicated entries as there's no weighted
|
||||||
|
// sketching for categorical features, the categories are the cut values.
|
||||||
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
|
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
|
||||||
dh::SegmentedUnique(column_sizes_scan.data().get(),
|
std::size_t n_uniques{0};
|
||||||
column_sizes_scan.data().get() + column_sizes_scan.size(),
|
if (p_sorted_weights) {
|
||||||
sorted_entries.begin(), sorted_entries.end(),
|
using Pair = thrust::tuple<Entry, float>;
|
||||||
new_column_scan.data().get(), sorted_entries.begin(),
|
auto d_sorted_entries = dh::ToSpan(sorted_entries);
|
||||||
[=] __device__(Entry const &l, Entry const &r) {
|
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
|
||||||
|
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
|
||||||
|
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
|
||||||
|
n_uniques = dh::SegmentedUnique(
|
||||||
|
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
|
||||||
|
val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
|
||||||
|
[=] __device__(Pair const& l, Pair const& r) {
|
||||||
|
Entry const& le = thrust::get<0>(l);
|
||||||
|
Entry const& re = thrust::get<0>(r);
|
||||||
|
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
|
||||||
|
return le.fvalue == re.fvalue;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
p_sorted_weights->resize(n_uniques);
|
||||||
|
} else {
|
||||||
|
n_uniques = dh::SegmentedUnique(
|
||||||
|
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
|
||||||
|
sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
|
||||||
|
sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
|
||||||
if (l.index == r.index) {
|
if (l.index == r.index) {
|
||||||
if (IsCat(d_feature_types, l.index)) {
|
if (IsCat(d_feature_types, l.index)) {
|
||||||
return l.fvalue == r.fvalue;
|
return l.fvalue == r.fvalue;
|
||||||
@ -150,14 +172,14 @@ void RemoveDuplicatedCategories(
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
sorted_entries.resize(n_uniques);
|
||||||
|
|
||||||
// Renew the column scan and cut scan based on categorical data.
|
// Renew the column scan and cut scan based on categorical data.
|
||||||
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
|
auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
|
||||||
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
|
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(info.num_col_ + 1);
|
||||||
info.num_col_ + 1);
|
|
||||||
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
|
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
|
||||||
dh::LaunchN(
|
dh::LaunchN(new_column_scan.size(),
|
||||||
new_column_scan.size(),
|
|
||||||
[=, d_new_cuts_size = dh::ToSpan(new_cuts_size),
|
[=, d_new_cuts_size = dh::ToSpan(new_cuts_size),
|
||||||
d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan),
|
d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan),
|
||||||
d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) {
|
d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) {
|
||||||
@ -167,15 +189,14 @@ void RemoveDuplicatedCategories(
|
|||||||
}
|
}
|
||||||
if (IsCat(d_feature_types, idx)) {
|
if (IsCat(d_feature_types, idx)) {
|
||||||
// Cut size is the same as number of categories in input.
|
// Cut size is the same as number of categories in input.
|
||||||
d_new_cuts_size[idx] =
|
d_new_cuts_size[idx] = d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
|
||||||
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
|
|
||||||
} else {
|
} else {
|
||||||
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
|
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// Turn size into ptr.
|
// Turn size into ptr.
|
||||||
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
|
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), new_cuts_size.cend(),
|
||||||
new_cuts_size.cend(), d_cuts_ptr.data());
|
d_cuts_ptr.data());
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
@ -209,8 +230,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
|||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
|
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
|
||||||
&sorted_entries, &column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
@ -276,8 +297,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
|||||||
&column_sizes_scan);
|
&column_sizes_scan);
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
|
||||||
&sorted_entries, &column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
|
|||||||
@ -240,10 +240,10 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran
|
|||||||
void SortByWeight(dh::device_vector<float>* weights,
|
void SortByWeight(dh::device_vector<float>* weights,
|
||||||
dh::device_vector<Entry>* sorted_entries);
|
dh::device_vector<Entry>* sorted_entries);
|
||||||
|
|
||||||
void RemoveDuplicatedCategories(
|
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
|
||||||
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
|
dh::device_vector<Entry>* p_sorted_entries,
|
||||||
dh::device_vector<Entry> *p_sorted_entries,
|
dh::device_vector<float>* p_sorted_weights,
|
||||||
dh::caching_device_vector<size_t> *p_column_sizes_scan);
|
dh::caching_device_vector<size_t>* p_column_sizes_scan);
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
// Compute sketch on DMatrix.
|
// Compute sketch on DMatrix.
|
||||||
@ -275,8 +275,8 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
|
|||||||
|
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
|
||||||
&sorted_entries, &column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
@ -354,8 +354,8 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
|
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
|
||||||
&sorted_entries, &column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
|
|||||||
@ -143,11 +143,14 @@ TEST(HistUtil, DeviceSketchCategoricalFeatures) {
|
|||||||
|
|
||||||
void TestMixedSketch() {
|
void TestMixedSketch() {
|
||||||
size_t n_samples = 1000, n_features = 2, n_categories = 3;
|
size_t n_samples = 1000, n_features = 2, n_categories = 3;
|
||||||
|
bst_bin_t n_bins = 64;
|
||||||
|
|
||||||
std::vector<float> data(n_samples * n_features);
|
std::vector<float> data(n_samples * n_features);
|
||||||
SimpleLCG gen;
|
SimpleLCG gen;
|
||||||
SimpleRealUniformDistribution<float> cat_d{0.0f, static_cast<float>(n_categories)};
|
SimpleRealUniformDistribution<float> cat_d{0.0f, static_cast<float>(n_categories)};
|
||||||
SimpleRealUniformDistribution<float> num_d{0.0f, 3.0f};
|
SimpleRealUniformDistribution<float> num_d{0.0f, 3.0f};
|
||||||
for (size_t i = 0; i < n_samples * n_features; ++i) {
|
for (size_t i = 0; i < n_samples * n_features; ++i) {
|
||||||
|
// two features, row major. The first column is numeric and the second is categorical.
|
||||||
if (i % 2 == 0) {
|
if (i % 2 == 0) {
|
||||||
data[i] = std::floor(cat_d(&gen));
|
data[i] = std::floor(cat_d(&gen));
|
||||||
} else {
|
} else {
|
||||||
@ -159,12 +162,75 @@ void TestMixedSketch() {
|
|||||||
m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||||
m->Info().feature_types.HostVector().push_back(FeatureType::kNumerical);
|
m->Info().feature_types.HostVector().push_back(FeatureType::kNumerical);
|
||||||
|
|
||||||
auto cuts = DeviceSketch(0, m.get(), 64);
|
auto cuts = DeviceSketch(0, m.get(), n_bins);
|
||||||
ASSERT_EQ(cuts.Values().size(), 64 + n_categories);
|
ASSERT_EQ(cuts.Values().size(), n_bins + n_categories);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DeviceSketchMixedFeatures) {
|
TEST(HistUtil, DeviceSketchMixedFeatures) { TestMixedSketch(); }
|
||||||
TestMixedSketch();
|
|
||||||
|
TEST(HistUtil, RemoveDuplicatedCategories) {
|
||||||
|
bst_row_t n_samples = 512;
|
||||||
|
bst_feature_t n_features = 3;
|
||||||
|
bst_cat_t n_categories = 5;
|
||||||
|
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
SimpleLCG rng;
|
||||||
|
SimpleRealUniformDistribution<float> cat_d{0.0f, static_cast<float>(n_categories)};
|
||||||
|
|
||||||
|
dh::device_vector<Entry> sorted_entries(n_samples * n_features);
|
||||||
|
for (std::size_t i = 0; i < n_samples; ++i) {
|
||||||
|
for (bst_feature_t j = 0; j < n_features; ++j) {
|
||||||
|
float fvalue{0.0f};
|
||||||
|
// The second column is categorical
|
||||||
|
if (j == 1) {
|
||||||
|
fvalue = std::floor(cat_d(&rng));
|
||||||
|
} else {
|
||||||
|
fvalue = i;
|
||||||
|
}
|
||||||
|
sorted_entries[i * n_features + j] = Entry{j, fvalue};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MetaInfo info;
|
||||||
|
info.num_col_ = n_features;
|
||||||
|
info.num_row_ = n_samples;
|
||||||
|
info.feature_types.HostVector() = std::vector<FeatureType>{
|
||||||
|
FeatureType::kNumerical, FeatureType::kCategorical, FeatureType::kNumerical};
|
||||||
|
ASSERT_EQ(info.feature_types.Size(), n_features);
|
||||||
|
|
||||||
|
HostDeviceVector<bst_row_t> cuts_ptr{0, n_samples, n_samples * 2, n_samples * 3};
|
||||||
|
cuts_ptr.SetDevice(0);
|
||||||
|
|
||||||
|
dh::device_vector<float> weight(n_samples * n_features, 0);
|
||||||
|
dh::Iota(dh::ToSpan(weight));
|
||||||
|
|
||||||
|
dh::caching_device_vector<bst_row_t> columns_ptr(4);
|
||||||
|
for (std::size_t i = 0; i < columns_ptr.size(); ++i) {
|
||||||
|
columns_ptr[i] = i * n_samples;
|
||||||
|
}
|
||||||
|
// sort into column major
|
||||||
|
thrust::sort_by_key(sorted_entries.begin(), sorted_entries.end(), weight.begin(),
|
||||||
|
detail::EntryCompareOp());
|
||||||
|
|
||||||
|
detail::RemoveDuplicatedCategories(ctx.gpu_id, info, cuts_ptr.DeviceSpan(), &sorted_entries,
|
||||||
|
&weight, &columns_ptr);
|
||||||
|
|
||||||
|
auto const& h_cptr = cuts_ptr.ConstHostVector();
|
||||||
|
ASSERT_EQ(h_cptr.back(), n_samples * 2 + n_categories);
|
||||||
|
// check numerical
|
||||||
|
for (std::size_t i = 0; i < n_samples; ++i) {
|
||||||
|
ASSERT_EQ(weight[i], i * 3);
|
||||||
|
}
|
||||||
|
auto beg = n_samples + n_categories;
|
||||||
|
for (std::size_t i = 0; i < n_samples; ++i) {
|
||||||
|
ASSERT_EQ(weight[i + beg], i * 3 + 2);
|
||||||
|
}
|
||||||
|
// check categorical
|
||||||
|
beg = n_samples;
|
||||||
|
for (std::size_t i = 0; i < n_categories; ++i) {
|
||||||
|
// all from the second column
|
||||||
|
ASSERT_EQ(static_cast<bst_feature_t>(weight[i + beg]) % n_features, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DeviceSketchMultipleColumns) {
|
TEST(HistUtil, DeviceSketchMultipleColumns) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user