Support categorical data in GPU weighted sketching. (#6508)

This commit is contained in:
Jiaming Yuan 2020-12-16 14:23:28 +08:00 committed by GitHub
parent 5c8ccf4455
commit 886486a519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 6 deletions

View File

@ -220,10 +220,12 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
}
void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
MetaInfo const& info, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts_per_feature,
size_t num_columns,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
auto weights = info.weights_.ConstDeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
@ -267,9 +269,10 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
&column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
// Extract cuts
sketch_container->Push(dh::ToSpan(sorted_entries),
@ -309,7 +312,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
device, batch, dmat->Info(), begin, end,
&sketch_container,
num_cuts_per_feature,
dmat->Info().num_col_,

View File

@ -122,10 +122,21 @@ TEST(HistUtil, DeviceSketchCategoricalAsNumeric) {
}
}
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) {
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins, bool weighted) {
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
auto dmat = GetDMatrixFromData(x, n, 1);
dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
if (weighted) {
std::vector<float> weights(n, 0);
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist(0, 1);
for (auto& v : weights) {
v = dist(&lcg);
}
dmat->Info().weights_.HostVector() = weights;
}
ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
std::sort(x.begin(), x.end());
@ -146,7 +157,8 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) {
}
TEST(HistUtil, DeviceSketchCategoricalFeatures) {
TestCategoricalSketch(1000, 256, 32);
TestCategoricalSketch(1000, 256, 32, false);
TestCategoricalSketch(1000, 256, 32, true);
}
TEST(HistUtil, DeviceSketchMultipleColumns) {