Support categorical data in GPU weighted sketching. (#6508)
This commit is contained in:
parent
5c8ccf4455
commit
886486a519
@ -220,10 +220,12 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||||
Span<const float> weights, size_t begin, size_t end,
|
MetaInfo const& info, size_t begin, size_t end,
|
||||||
SketchContainer* sketch_container, int num_cuts_per_feature,
|
SketchContainer* sketch_container, int num_cuts_per_feature,
|
||||||
size_t num_columns,
|
size_t num_columns,
|
||||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||||
|
auto weights = info.weights_.ConstDeviceSpan();
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
const auto& host_data = page.data.ConstHostVector();
|
const auto& host_data = page.data.ConstHostVector();
|
||||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||||
@ -267,9 +269,10 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
|||||||
batch_it, dummy_is_valid,
|
batch_it, dummy_is_valid,
|
||||||
0, sorted_entries.size(),
|
0, sorted_entries.size(),
|
||||||
&cuts_ptr, &column_sizes_scan);
|
&cuts_ptr, &column_sizes_scan);
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
|
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
||||||
|
&column_sizes_scan);
|
||||||
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
|
|
||||||
// Extract cuts
|
// Extract cuts
|
||||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||||
@ -309,7 +312,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
|||||||
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
||||||
info.group_ptr_.cend());
|
info.group_ptr_.cend());
|
||||||
ProcessWeightedBatch(
|
ProcessWeightedBatch(
|
||||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
device, batch, dmat->Info(), begin, end,
|
||||||
&sketch_container,
|
&sketch_container,
|
||||||
num_cuts_per_feature,
|
num_cuts_per_feature,
|
||||||
dmat->Info().num_col_,
|
dmat->Info().num_col_,
|
||||||
|
|||||||
@ -122,10 +122,21 @@ TEST(HistUtil, DeviceSketchCategoricalAsNumeric) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) {
|
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins, bool weighted) {
|
||||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||||
dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||||
|
|
||||||
|
if (weighted) {
|
||||||
|
std::vector<float> weights(n, 0);
|
||||||
|
SimpleLCG lcg;
|
||||||
|
SimpleRealUniformDistribution<float> dist(0, 1);
|
||||||
|
for (auto& v : weights) {
|
||||||
|
v = dist(&lcg);
|
||||||
|
}
|
||||||
|
dmat->Info().weights_.HostVector() = weights;
|
||||||
|
}
|
||||||
|
|
||||||
ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
|
ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
|
||||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||||
std::sort(x.begin(), x.end());
|
std::sort(x.begin(), x.end());
|
||||||
@ -146,7 +157,8 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DeviceSketchCategoricalFeatures) {
|
TEST(HistUtil, DeviceSketchCategoricalFeatures) {
|
||||||
TestCategoricalSketch(1000, 256, 32);
|
TestCategoricalSketch(1000, 256, 32, false);
|
||||||
|
TestCategoricalSketch(1000, 256, 32, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DeviceSketchMultipleColumns) {
|
TEST(HistUtil, DeviceSketchMultipleColumns) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user