From 886486a5193fc00309df950d19df57cb19b158d7 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 16 Dec 2020 14:23:28 +0800 Subject: [PATCH] Support categorical data in GPU weighted sketching. (#6508) --- src/common/hist_util.cu | 11 +++++++---- tests/cpp/common/test_hist_util.cu | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 4d5ecf287..c75d5c2ef 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -220,10 +220,12 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page, } void ProcessWeightedBatch(int device, const SparsePage& page, - Span weights, size_t begin, size_t end, + MetaInfo const& info, size_t begin, size_t end, SketchContainer* sketch_container, int num_cuts_per_feature, size_t num_columns, bool is_ranking, Span d_group_ptr) { + auto weights = info.weights_.ConstDeviceSpan(); + dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); dh::device_vector sorted_entries(host_data.begin() + begin, @@ -267,9 +269,10 @@ void ProcessWeightedBatch(int device, const SparsePage& page, batch_it, dummy_is_valid, 0, sorted_entries.size(), &cuts_ptr, &column_sizes_scan); - - auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, + &column_sizes_scan); + auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); // Extract cuts sketch_container->Push(dh::ToSpan(sorted_entries), @@ -309,7 +312,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, dh::caching_device_vector groups(info.group_ptr_.cbegin(), info.group_ptr_.cend()); ProcessWeightedBatch( - device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end, + device, batch, dmat->Info(), begin, end, &sketch_container, num_cuts_per_feature, dmat->Info().num_col_, diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 332ec9233..65c51c5b3 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -122,10 +122,21 @@ TEST(HistUtil, DeviceSketchCategoricalAsNumeric) { } } -void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) { +void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins, bool weighted) { auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); auto dmat = GetDMatrixFromData(x, n, 1); dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); + + if (weighted) { + std::vector weights(n, 0); + SimpleLCG lcg; + SimpleRealUniformDistribution dist(0, 1); + for (auto& v : weights) { + v = dist(&lcg); + } + dmat->Info().weights_.HostVector() = weights; + } + ASSERT_EQ(dmat->Info().feature_types.Size(), 1); auto cuts = DeviceSketch(0, dmat.get(), num_bins); std::sort(x.begin(), x.end()); @@ -146,7 +157,8 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) { } TEST(HistUtil, DeviceSketchCategoricalFeatures) { - TestCategoricalSketch(1000, 256, 32); + TestCategoricalSketch(1000, 256, 32, false); + TestCategoricalSketch(1000, 256, 32, true); } TEST(HistUtil, DeviceSketchMultipleColumns) {