Support categorical data in GPU weighted sketching. (#6508)
This commit is contained in:
@@ -220,10 +220,12 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
Span<const float> weights, size_t begin, size_t end,
|
||||
MetaInfo const& info, size_t begin, size_t end,
|
||||
SketchContainer* sketch_container, int num_cuts_per_feature,
|
||||
size_t num_columns,
|
||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
@@ -267,9 +269,10 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
batch_it, dummy_is_valid,
|
||||
0, sorted_entries.size(),
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
|
||||
&column_sizes_scan);
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
|
||||
// Extract cuts
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
@@ -309,7 +312,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
||||
info.group_ptr_.cend());
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
||||
device, batch, dmat->Info(), begin, end,
|
||||
&sketch_container,
|
||||
num_cuts_per_feature,
|
||||
dmat->Info().num_col_,
|
||||
|
||||
Reference in New Issue
Block a user