diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index ebd38b7ae..f44f416a1 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -33,79 +33,6 @@ namespace common { constexpr float SketchContainer::kFactor; namespace detail { - -// Count the entries in each column and exclusive scan -void ExtractCutsSparse(int device, common::Span cuts_ptr, - Span sorted_data, - Span column_sizes_scan, - Span out_cuts) { - dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { - // Each thread is responsible for obtaining one cut from the sorted input - size_t column_idx = dh::SegmentId(cuts_ptr, idx); - size_t column_size = - column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; - size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx]; - size_t cut_idx = idx - cuts_ptr[column_idx]; - Span column_entries = - sorted_data.subspan(column_sizes_scan[column_idx], column_size); - size_t rank = (column_entries.size() * cut_idx) / - static_cast(num_available_cuts); - out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1, - column_entries[rank].fvalue); - }); -} - -void ExtractWeightedCutsSparse(int device, - common::Span cuts_ptr, - Span sorted_data, - Span weights_scan, - Span column_sizes_scan, - Span cuts) { - dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { - // Each thread is responsible for obtaining one cut from the sorted input - size_t column_idx = dh::SegmentId(cuts_ptr, idx); - size_t column_size = - column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; - size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx]; - size_t cut_idx = idx - cuts_ptr[column_idx]; - - Span column_entries = - sorted_data.subspan(column_sizes_scan[column_idx], column_size); - - Span column_weights_scan = - weights_scan.subspan(column_sizes_scan[column_idx], column_size); - float total_column_weight = column_weights_scan.back(); - size_t sample_idx = 0; - if (cut_idx == 0) { - // First cut - sample_idx = 0; - } else if (cut_idx == num_available_cuts) { - // Last cut - sample_idx = column_entries.size() - 1; - } else if (num_available_cuts == column_size) { - // There are less samples available than our buffer - // Take every available sample - sample_idx = cut_idx; - } else { - bst_float rank = (total_column_weight * cut_idx) / - static_cast(num_available_cuts); - sample_idx = thrust::upper_bound(thrust::seq, - column_weights_scan.begin(), - column_weights_scan.end(), - rank) - - column_weights_scan.begin(); - sample_idx = - max(static_cast(0), - min(sample_idx, column_entries.size() - 1)); - } - // repeated values will be filtered out later. - bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; - bst_float rmax = column_weights_scan[sample_idx]; - cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin, - column_entries[sample_idx].fvalue); - }); -} - size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) { double eps = 1.0 / (WQSketch::kFactor * max_bins); size_t dummy_nlevel; @@ -220,19 +147,16 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, &cuts_ptr, &column_sizes_scan); auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); - dh::caching_device_vector cuts(h_cuts_ptr.back()); auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); - CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); - detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts)); // add cuts into sketches + sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), + d_cuts_ptr, h_cuts_ptr.back()); sorted_entries.clear(); sorted_entries.shrink_to_fit(); CHECK_EQ(sorted_entries.capacity(), 0); CHECK_NE(cuts_ptr.Size(), 0); - sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } void ProcessWeightedBatch(int device, const SparsePage& page, @@ -285,18 +209,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page, &cuts_ptr, &column_sizes_scan); auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); - dh::caching_device_vector cuts(h_cuts_ptr.back()); auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); // Extract cuts - detail::ExtractWeightedCutsSparse(device, d_cuts_ptr, - dh::ToSpan(sorted_entries), - dh::ToSpan(temp_weights), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); - - // add cuts into sketches - sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); + sketch_container->Push(dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back(), dh::ToSpan(temp_weights)); + sorted_entries.clear(); + sorted_entries.shrink_to_fit(); } HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index f1034040c..7047716bf 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -28,37 +28,6 @@ struct EntryCompareOp { } }; -/** - * \brief Extracts the cuts from sorted data. - * - * \param device The device. - * \param cuts_ptr Column pointers to CSC structured cuts - * \param sorted_data Sorted entries in segments of columns - * \param column_sizes_scan Describes the boundaries of column segments in sorted data - * \param out_cuts Output cut values - */ -void ExtractCutsSparse(int device, common::Span cuts_ptr, - Span sorted_data, - Span column_sizes_scan, - Span out_cuts); - -/** - * \brief Extracts the cuts from sorted data, considering weights. - * - * \param device The device. - * \param cuts_ptr Column pointers to CSC structured cuts - * \param sorted_data Sorted entries in segments of columns. - * \param weights_scan Inclusive scan of weights for each entry in sorted_data. - * \param column_sizes_scan Describes the boundaries of column segments in sorted data. - * \param cuts Output cuts. - */ -void ExtractWeightedCutsSparse(int device, - common::Span cuts_ptr, - Span sorted_data, - Span weights_scan, - Span column_sizes_scan, - Span cuts); - // Get column size from adapter batch and for output cuts. template void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature, @@ -186,16 +155,12 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); - dh::caching_device_vector cuts(h_cuts_ptr.back()); // Extract the cuts from all columns concurrently - detail::ExtractCutsSparse(device, d_cuts_ptr, - dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); + sketch_container->Push(dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back()); sorted_entries.clear(); sorted_entries.shrink_to_fit(); - - sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } template @@ -263,16 +228,11 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); // Extract cuts - dh::caching_device_vector cuts(h_cuts_ptr.back()); - detail::ExtractWeightedCutsSparse(device, d_cuts_ptr, - dh::ToSpan(sorted_entries), - dh::ToSpan(temp_weights), - dh::ToSpan(column_sizes_scan), - dh::ToSpan(cuts)); + sketch_container->Push(dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back(), dh::ToSpan(temp_weights)); sorted_entries.clear(); sorted_entries.shrink_to_fit(); - // add cuts into sketches - sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts); } /* @@ -324,4 +284,4 @@ void AdapterDeviceSketch(Batch batch, int num_bins, } // namespace common } // namespace xgboost -#endif // COMMON_HIST_UTIL_CUH_ \ No newline at end of file +#endif // COMMON_HIST_UTIL_CUH_ diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 52d0e37e9..3b4d846ac 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -24,29 +24,83 @@ using WQSketch = HostSketchContainer::WQSketch; using SketchEntry = WQSketch::Entry; // Algorithm 4 in XGBoost's paper, using binary search to find i. -__device__ SketchEntry BinarySearchQuery(Span const& entries, float rank) { - assert(entries.size() >= 2); +template +__device__ SketchEntry BinarySearchQuery(EntryIter beg, EntryIter end, float rank) { + assert(end - beg >= 2); rank *= 2; - if (rank < entries.front().rmin + entries.front().rmax) { - return entries.front(); + auto front = *beg; + if (rank < front.rmin + front.rmax) { + return *beg; } - if (rank >= entries.back().rmin + entries.back().rmax) { - return entries.back(); + auto back = *(end - 1); + if (rank >= back.rmin + back.rmax) { + return back; } - auto begin = dh::MakeTransformIterator( - entries.begin(), [=] __device__(SketchEntry const &entry) { + auto search_begin = dh::MakeTransformIterator( + beg, [=] __device__(SketchEntry const &entry) { return entry.rmin + entry.rmax; }); - auto end = begin + entries.size(); - auto i = thrust::upper_bound(thrust::seq, begin + 1, end - 1, rank) - begin - 1; - if (rank < entries[i].RMinNext() + entries[i+1].RMaxPrev()) { - return entries[i]; + auto search_end = search_begin + (end - beg); + auto i = + thrust::upper_bound(thrust::seq, search_begin + 1, search_end - 1, rank) - + search_begin - 1; + if (rank < (*(beg + i)).RMinNext() + (*(beg + i + 1)).RMaxPrev()) { + return *(beg + i); } else { - return entries[i+1]; + return *(beg + i + 1); } } +template +void PruneImpl(int device, + common::Span cuts_ptr, + Span sorted_data, + Span columns_ptr_in, // could be ptr for data or cuts + Span out_cuts, + ToSketchEntry to_sketch_entry) { + dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { + size_t column_id = dh::SegmentId(cuts_ptr, idx); + auto out_column = out_cuts.subspan( + cuts_ptr[column_id], cuts_ptr[column_id + 1] - cuts_ptr[column_id]); + auto in_column = sorted_data.subspan(columns_ptr_in[column_id], + columns_ptr_in[column_id + 1] - + columns_ptr_in[column_id]); + auto to = cuts_ptr[column_id + 1] - cuts_ptr[column_id]; + idx -= cuts_ptr[column_id]; + auto front = to_sketch_entry(0ul, in_column, column_id); + auto back = to_sketch_entry(in_column.size() - 1, in_column, column_id); + + if (in_column.size() <= to) { + // cut idx equals sample idx + out_column[idx] = to_sketch_entry(idx, in_column, column_id); + return; + } + // 1 thread for each output. See A.4 for detail. + auto d_out = out_column; + if (idx == 0) { + d_out.front() = front; + return; + } + if (idx == to - 1) { + d_out.back() = back; + return; + } + + float w = back.rmin - front.rmax; + assert(w != 0); + auto budget = static_cast(d_out.size()); + assert(budget != 0); + auto q = ((static_cast(idx) * w) / (static_cast(to) - 1.0f) + front.rmax); + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) { + auto e = to_sketch_entry(idx, in_column, column_id); + return e; + }); + d_out[idx] = BinarySearchQuery(it, it + in_column.size(), q); + }); +} + template void CopyTo(Span out, Span src) { CHECK_EQ(out.size(), src.size()); @@ -249,27 +303,58 @@ void MergeImpl(int32_t device, Span const &d_x, }); } -void SketchContainer::Push(common::Span cuts_ptr, - dh::caching_device_vector* entries) { - timer_.Start(__func__); - dh::safe_cuda(cudaSetDevice(device_)); - // Copy or merge the new cuts, pruning is performed during `MakeCuts`. - if (this->Current().size() == 0) { - CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size()); - // See thrust issue 1030, THRUST_CPP_DIALECT is not correctly defined so - // move constructor is not used. - this->Current().swap(*entries); - CHECK_EQ(entries->size(), 0); - auto d_cuts_ptr = this->columns_ptr_.DevicePointer(); - thrust::copy(thrust::device, cuts_ptr.data(), - cuts_ptr.data() + cuts_ptr.size(), d_cuts_ptr); +void SketchContainer::Push(Span entries, Span columns_ptr, + common::Span cuts_ptr, + size_t total_cuts, Span weights) { + Span out; + dh::caching_device_vector cuts; + bool first_window = this->Current().empty(); + if (!first_window) { + cuts.resize(total_cuts); + out = dh::ToSpan(cuts); } else { - auto d_entries = dh::ToSpan(*entries); - this->Merge(cuts_ptr, d_entries); - this->FixError(); + this->Current().resize(total_cuts); + out = dh::ToSpan(this->Current()); + } + + if (weights.empty()) { + auto to_sketch_entry = [] __device__(size_t sample_idx, + Span const &column, + size_t) { + float rmin = sample_idx; + float rmax = sample_idx + 1; + return SketchEntry{rmin, rmax, 1, column[sample_idx].fvalue}; + }; // NOLINT + PruneImpl(device_, cuts_ptr, entries, columns_ptr, out, + to_sketch_entry); + } else { + auto to_sketch_entry = [weights, columns_ptr] __device__( + size_t sample_idx, + Span const &column, + size_t column_id) { + Span column_weights_scan = + weights.subspan(columns_ptr[column_id], column.size()); + float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; + float rmax = column_weights_scan[sample_idx]; + float wmin = rmax - rmin; + wmin = wmin < 0 ? kRtEps : wmin; // GPU scan can generate floating error. + return SketchEntry{rmin, rmax, wmin, column[sample_idx].fvalue}; + }; // NOLINT + PruneImpl(device_, cuts_ptr, entries, columns_ptr, out, + to_sketch_entry); + } + + if (!first_window) { + CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size()); + this->Merge(cuts_ptr, out); + this->FixError(); + } else { + this->columns_ptr_.SetDevice(device_); + this->columns_ptr_.Resize(cuts_ptr.size()); + + auto d_cuts_ptr = this->columns_ptr_.DeviceSpan(); + CopyTo(d_cuts_ptr, cuts_ptr); } - CHECK_NE(this->columns_ptr_.Size(), 0); - timer_.Stop(__func__); } size_t SketchContainer::Unique() { @@ -317,41 +402,11 @@ void SketchContainer::Prune(size_t to) { auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan(); auto out = dh::ToSpan(this->Other()); auto in = dh::ToSpan(this->Current()); - dh::LaunchN(0, to_total, [=] __device__(size_t idx) { - size_t column_id = dh::SegmentId(d_columns_ptr_out, idx); - auto out_column = out.subspan(d_columns_ptr_out[column_id], - d_columns_ptr_out[column_id + 1] - - d_columns_ptr_out[column_id]); - auto in_column = in.subspan(d_columns_ptr_in[column_id], - d_columns_ptr_in[column_id + 1] - - d_columns_ptr_in[column_id]); - idx -= d_columns_ptr_out[column_id]; - // Input has lesser columns than `to`, just copy them to the output. This is correct - // as the new output size is calculated based on both the size of `to` and current - // column. - if (in_column.size() <= to) { - out_column[idx] = in_column[idx]; - return; - } - // 1 thread for each output. See A.4 for detail. - auto entries = in_column; - auto d_out = out_column; - if (idx == 0) { - d_out.front() = entries.front(); - return; - } - if (idx == to - 1) { - d_out.back() = entries.back(); - return; - } - - float w = entries.back().rmin - entries.front().rmax; - assert(w != 0); - auto budget = static_cast(d_out.size()); - assert(budget != 0); - auto q = ((idx * w) / (to - 1) + entries.front().rmax); - d_out[idx] = BinarySearchQuery(entries, q); - }); + auto no_op = [] __device__(size_t sample_idx, + Span const &entries, + size_t) { return entries[sample_idx]; }; // NOLINT + PruneImpl(device_, d_columns_ptr_out, in, d_columns_ptr_in, out, + no_op); this->columns_ptr_.HostVector() = new_columns_ptr.HostVector(); this->Alternate(); timer_.Stop(__func__); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index cd5833914..00cc19329 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -96,13 +96,21 @@ class SketchContainer { * addition inside `RMinNext` and subtraction in `RMaxPrev`. */ void FixError(); - /* \brief Push a CSC structured cut matrix. */ - void Push(common::Span cuts_ptr, - dh::caching_device_vector* entries); + /* \brief Push sorted entries. + * + * \param entries Sorted entries. + * \param columns_ptr CSC pointer for entries. + * \param cuts_ptr CSC pointer for cuts. + * \param total_cuts Total number of cuts, equal to the back of cuts_ptr. + * \param weights (optional) data weights. + */ + void Push(Span entries, Span columns_ptr, + common::Span cuts_ptr, size_t total_cuts, + Span weights = {}); /* \brief Prune the quantile structure. * - * \param to The maximum size of pruned quantile. If the size of quantile structure is - * already less than `to`, then no operation is performed. + * \param to The maximum size of pruned quantile. If the size of quantile + * structure is already less than `to`, then no operation is performed. */ void Prune(size_t to); /* \brief Merge another set of sketch. @@ -135,8 +143,8 @@ struct SketchUnique { return a.value - b.value == 0; } }; -} // anonymous detail +} // namespace detail } // namespace common } // namespace xgboost -#endif // XGBOOST_COMMON_QUANTILE_CUH_ \ No newline at end of file +#endif // XGBOOST_COMMON_QUANTILE_CUH_ diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index d857a0118..006c036d3 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -29,6 +29,24 @@ TEST(AtomicAdd, SizeT) { TestAtomicSizeT(); } +void TestSegmentID() { + std::vector segments{0, 1, 3}; + thrust::device_vector d_segments(segments); + auto s_segments = dh::ToSpan(d_segments); + dh::LaunchN(0, 1, [=]__device__(size_t idx) { + auto id = dh::SegmentId(s_segments, 0); + SPAN_CHECK(id == 0); + id = dh::SegmentId(s_segments, 1); + SPAN_CHECK(id == 1); + id = dh::SegmentId(s_segments, 2); + SPAN_CHECK(id == 1); + }); +} + +TEST(SegmentID, Basic) { + TestSegmentID(); +} + TEST(SegmentedUnique, Basic) { std::vector values{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.62448811531066895f, 0.4f}; std::vector segments{0, 3, 6}; diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index f7c7e22e3..5279135a9 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -9,11 +9,11 @@ namespace common { TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; SketchContainer sketch(kBins, kCols, kRows, 0); - dh::caching_device_vector entries; + dh::caching_device_vector entries; dh::device_vector cuts_ptr(kCols+1); thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); // Push empty - sketch.Push(dh::ToSpan(cuts_ptr), &entries); + sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0); ASSERT_EQ(sketch.Data().size(), 0); }