Merge extract cuts into QuantileContainer. (#6125)
* Use pruning for initial summary construction.
This commit is contained in:
@@ -33,79 +33,6 @@ namespace common {
|
||||
constexpr float SketchContainer::kFactor;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry const> sorted_data,
|
||||
Span<size_t const> column_sizes_scan,
|
||||
Span<SketchEntry> out_cuts) {
|
||||
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
|
||||
// Each thread is responsible for obtaining one cut from the sorted input
|
||||
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
|
||||
size_t column_size =
|
||||
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
|
||||
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
|
||||
size_t cut_idx = idx - cuts_ptr[column_idx];
|
||||
Span<Entry const> column_entries =
|
||||
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
|
||||
size_t rank = (column_entries.size() * cut_idx) /
|
||||
static_cast<float>(num_available_cuts);
|
||||
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
|
||||
column_entries[rank].fvalue);
|
||||
});
|
||||
}
|
||||
|
||||
void ExtractWeightedCutsSparse(int device,
|
||||
common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry> sorted_data,
|
||||
Span<float> weights_scan,
|
||||
Span<size_t> column_sizes_scan,
|
||||
Span<SketchEntry> cuts) {
|
||||
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
|
||||
// Each thread is responsible for obtaining one cut from the sorted input
|
||||
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
|
||||
size_t column_size =
|
||||
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
|
||||
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
|
||||
size_t cut_idx = idx - cuts_ptr[column_idx];
|
||||
|
||||
Span<Entry> column_entries =
|
||||
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
|
||||
|
||||
Span<float> column_weights_scan =
|
||||
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
|
||||
float total_column_weight = column_weights_scan.back();
|
||||
size_t sample_idx = 0;
|
||||
if (cut_idx == 0) {
|
||||
// First cut
|
||||
sample_idx = 0;
|
||||
} else if (cut_idx == num_available_cuts) {
|
||||
// Last cut
|
||||
sample_idx = column_entries.size() - 1;
|
||||
} else if (num_available_cuts == column_size) {
|
||||
// There are less samples available than our buffer
|
||||
// Take every available sample
|
||||
sample_idx = cut_idx;
|
||||
} else {
|
||||
bst_float rank = (total_column_weight * cut_idx) /
|
||||
static_cast<float>(num_available_cuts);
|
||||
sample_idx = thrust::upper_bound(thrust::seq,
|
||||
column_weights_scan.begin(),
|
||||
column_weights_scan.end(),
|
||||
rank) -
|
||||
column_weights_scan.begin();
|
||||
sample_idx =
|
||||
max(static_cast<size_t>(0),
|
||||
min(sample_idx, column_entries.size() - 1));
|
||||
}
|
||||
// repeated values will be filtered out later.
|
||||
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
|
||||
bst_float rmax = column_weights_scan[sample_idx];
|
||||
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
|
||||
column_entries[sample_idx].fvalue);
|
||||
});
|
||||
}
|
||||
|
||||
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
|
||||
double eps = 1.0 / (WQSketch::kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
@@ -220,19 +147,16 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
|
||||
|
||||
// add cuts into sketches
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
|
||||
d_cuts_ptr, h_cuts_ptr.back());
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
CHECK_EQ(sorted_entries.capacity(), 0);
|
||||
CHECK_NE(cuts_ptr.Size(), 0);
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
@@ -285,18 +209,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
// Extract cuts
|
||||
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(temp_weights),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
|
||||
// add cuts into sketches
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
}
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
|
||||
Reference in New Issue
Block a user