Merge extract cuts into QuantileContainer. (#6125)

* Use pruning for initial summary construction.
This commit is contained in:
Jiaming Yuan
2020-09-18 16:36:39 +08:00
committed by GitHub
parent cc82ca167a
commit e319b63f9e
6 changed files with 171 additions and 210 deletions

View File

@@ -33,79 +33,6 @@ namespace common {
constexpr float SketchContainer::kFactor;
namespace detail {
// Count the entries in each column and exclusive scan
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
size_t cut_idx = idx - cuts_ptr[column_idx];
Span<Entry const> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
size_t rank = (column_entries.size() * cut_idx) /
static_cast<float>(num_available_cuts);
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
column_entries[rank].fvalue);
});
}
void ExtractWeightedCutsSparse(int device,
common::Span<SketchContainer::OffsetT const> cuts_ptr,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
size_t cut_idx = idx - cuts_ptr[column_idx];
Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
Span<float> column_weights_scan =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
float total_column_weight = column_weights_scan.back();
size_t sample_idx = 0;
if (cut_idx == 0) {
// First cut
sample_idx = 0;
} else if (cut_idx == num_available_cuts) {
// Last cut
sample_idx = column_entries.size() - 1;
} else if (num_available_cuts == column_size) {
// There are less samples available than our buffer
// Take every available sample
sample_idx = cut_idx;
} else {
bst_float rank = (total_column_weight * cut_idx) /
static_cast<float>(num_available_cuts);
sample_idx = thrust::upper_bound(thrust::seq,
column_weights_scan.begin(),
column_weights_scan.end(),
rank) -
column_weights_scan.begin();
sample_idx =
max(static_cast<size_t>(0),
min(sample_idx, column_entries.size() - 1));
}
// repeated values will be filtered out later.
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
bst_float rmax = column_weights_scan[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
column_entries[sample_idx].fvalue);
});
}
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t dummy_nlevel;
@@ -220,19 +147,16 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
// add cuts into sketches
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
d_cuts_ptr, h_cuts_ptr.back());
sorted_entries.clear();
sorted_entries.shrink_to_fit();
CHECK_EQ(sorted_entries.capacity(), 0);
CHECK_NE(cuts_ptr.Size(), 0);
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
void ProcessWeightedBatch(int device, const SparsePage& page,
@@ -285,18 +209,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
// Extract cuts
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
// add cuts into sketches
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
sketch_container->Push(dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
}
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,