Support categorical data in GPU weighted sketching. (#6508)
This commit is contained in:
parent
5c8ccf4455
commit
886486a519
@ -220,10 +220,12 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
Span<const float> weights, size_t begin, size_t end,
|
||||
MetaInfo const& info, size_t begin, size_t end,
|
||||
SketchContainer* sketch_container, int num_cuts_per_feature,
|
||||
size_t num_columns,
|
||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
@ -267,9 +269,10 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
batch_it, dummy_is_valid,
|
||||
0, sorted_entries.size(),
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
||||
&column_sizes_scan);
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
|
||||
// Extract cuts
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
@ -309,7 +312,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
||||
info.group_ptr_.cend());
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
||||
device, batch, dmat->Info(), begin, end,
|
||||
&sketch_container,
|
||||
num_cuts_per_feature,
|
||||
dmat->Info().num_col_,
|
||||
|
||||
@ -122,10 +122,21 @@ TEST(HistUtil, DeviceSketchCategoricalAsNumeric) {
|
||||
}
|
||||
}
|
||||
|
||||
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) {
|
||||
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins, bool weighted) {
|
||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||
|
||||
if (weighted) {
|
||||
std::vector<float> weights(n, 0);
|
||||
SimpleLCG lcg;
|
||||
SimpleRealUniformDistribution<float> dist(0, 1);
|
||||
for (auto& v : weights) {
|
||||
v = dist(&lcg);
|
||||
}
|
||||
dmat->Info().weights_.HostVector() = weights;
|
||||
}
|
||||
|
||||
ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
std::sort(x.begin(), x.end());
|
||||
@ -146,7 +157,8 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32);
|
||||
TestCategoricalSketch(1000, 256, 32, false);
|
||||
TestCategoricalSketch(1000, 256, 32, true);
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMultipleColumns) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user