Support categorical data in GPU sketching. (#6137)

This commit is contained in:
Jiaming Yuan
2020-09-21 13:53:06 +08:00
committed by GitHub
parent c932fb50a1
commit 210c131ce7
6 changed files with 196 additions and 62 deletions

View File

@@ -24,6 +24,7 @@
#include "hist_util.cuh"
#include "math.h" // NOLINT
#include "quantile.h"
#include "categorical.h"
#include "xgboost/host_device_vector.h"
@@ -121,11 +122,59 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
return a.index == b.index;
});
}
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};
void RemoveDuplicatedCategories(
int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
dh::device_vector<Entry> *p_sorted_entries,
dh::caching_device_vector<size_t> const &column_sizes_scan) {
auto d_feature_types = info.feature_types.ConstDeviceSpan();
if (!info.feature_types.Empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
IsCatOp{})) {
auto& sorted_entries = *p_sorted_entries;
// Removing duplicated entries in categorical features.
dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
dh::SegmentedUnique(column_sizes_scan.data().get(),
column_sizes_scan.data().get() +
column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const &l, Entry const &r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
// Renew the column scan and cut scan based on categorical data.
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
info.num_col_ + 1);
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) {
if (IsCat(d_feature_types, idx)) {
d_new_cuts_size[idx] =
d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
} else {
d_new_cuts_size[idx] = d_cuts_ptr[idx] - d_cuts_ptr[idx];
}
});
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
new_cuts_size.cend(), d_cuts_ptr.data());
}
}
} // namespace detail
void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
SketchContainer *sketch_container, int num_cuts_per_feature,
size_t num_columns) {
void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
size_t begin, size_t end, SketchContainer *sketch_container,
int num_cuts_per_feature, size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
@@ -145,9 +194,10 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
// add cuts into sketches
@@ -221,6 +271,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) {
dmat->Info().feature_types.SetDevice(device);
dmat->Info().feature_types.ConstDevicePointer(); // pull to device early
// Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts_per_feature =
@@ -233,7 +285,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
device, num_cuts_per_feature, has_weights);
HistogramCuts cuts;
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_,
dmat->Info().num_row_, device);
dmat->Info().weights_.SetDevice(device);
@@ -253,8 +305,8 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dmat->Info().num_col_,
is_ranking, dh::ToSpan(groups));
} else {
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature,
dmat->Info().num_col_);
ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container,
num_cuts_per_feature, dmat->Info().num_col_);
}
}
}