Support categorical data in GPU weighted sketching. (#6508)

This commit is contained in:
Jiaming Yuan
2020-12-16 14:23:28 +08:00
committed by GitHub
parent 5c8ccf4455
commit 886486a519
2 changed files with 21 additions and 6 deletions

View File

@@ -220,10 +220,12 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
}
void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
MetaInfo const& info, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts_per_feature,
size_t num_columns,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
auto weights = info.weights_.ConstDeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
@@ -267,9 +269,10 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
&column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
// Extract cuts
sketch_container->Push(dh::ToSpan(sorted_entries),
@@ -309,7 +312,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
device, batch, dmat->Info(), begin, end,
&sketch_container,
num_cuts_per_feature,
dmat->Info().num_col_,