Implement sketching with Hessian on GPU. (#9399)

- Prepare for implementing approx on GPU.
- Unify the code path between weighted and uniform sketching on DMatrix.
This commit is contained in:
Jiaming Yuan 2023-07-24 15:43:03 +08:00 committed by GitHub
parent 851cba931e
commit a196443a07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 446 additions and 230 deletions

View File

@ -185,10 +185,10 @@ class MetaInfo {
return data_split_mode == DataSplitMode::kRow; return data_split_mode == DataSplitMode::kRow;
} }
/*! \brief Whether the data is split column-wise. */ /** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
return data_split_mode == DataSplitMode::kCol; /** @brief Whether this is a learning to rank data. */
} bool IsRanking() const { return !group_ptr_.empty(); }
/*! /*!
* \brief A convenient method to check if we are doing vertical federated learning, which requires * \brief A convenient method to check if we are doing vertical federated learning, which requires
@ -249,7 +249,7 @@ struct BatchParam {
/** /**
* \brief Hessian, used for sketching with future approx implementation. * \brief Hessian, used for sketching with future approx implementation.
*/ */
common::Span<float> hess; common::Span<float const> hess;
/** /**
* \brief Whether should we force DMatrix to regenerate the batch. Only used for * \brief Whether should we force DMatrix to regenerate the batch. Only used for
* GHistIndex. * GHistIndex.
@ -279,7 +279,7 @@ struct BatchParam {
* Get batch with sketch weighted by hessian. The batch will be regenerated if the * Get batch with sketch weighted by hessian. The batch will be regenerated if the
* span is changed, so caller should keep the span for each iteration. * span is changed, so caller should keep the span for each iteration.
*/ */
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate) BatchParam(bst_bin_t max_bin, common::Span<float const> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {} : max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
[[nodiscard]] bool ParamNotEqual(BatchParam const& other) const { [[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {

View File

@ -49,11 +49,12 @@
#ifndef XGBOOST_HOST_DEVICE_VECTOR_H_ #ifndef XGBOOST_HOST_DEVICE_VECTOR_H_
#define XGBOOST_HOST_DEVICE_VECTOR_H_ #define XGBOOST_HOST_DEVICE_VECTOR_H_
#include <initializer_list> #include <xgboost/context.h> // for DeviceOrd
#include <vector> #include <xgboost/span.h> // for Span
#include <type_traits>
#include "span.h" #include <initializer_list>
#include <type_traits>
#include <vector>
namespace xgboost { namespace xgboost {
@ -133,6 +134,7 @@ class HostDeviceVector {
GPUAccess DeviceAccess() const; GPUAccess DeviceAccess() const;
void SetDevice(int device) const; void SetDevice(int device) const;
void SetDevice(DeviceOrd device) const;
void Resize(size_t new_size, T v = T()); void Resize(size_t new_size, T v = T());

View File

@ -12,8 +12,8 @@
#include "../data/gradient_index.h" // for GHistIndexMatrix #include "../data/gradient_index.h" // for GHistIndexMatrix
#include "quantile.h" #include "quantile.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/context.h" // Context #include "xgboost/context.h" // for Context
#include "xgboost/data.h" // SparsePage, SortedCSCPage #include "xgboost/data.h" // for SparsePage, SortedCSCPage
#if defined(XGBOOST_MM_PREFETCH_PRESENT) #if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h> #include <xmmintrin.h>
@ -30,7 +30,7 @@ HistogramCuts::HistogramCuts() {
} }
HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted, HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted,
Span<float> const hessian) { Span<float const> hessian) {
HistogramCuts out; HistogramCuts out;
auto const &info = m->Info(); auto const &info = m->Info();
auto n_threads = ctx->Threads(); auto n_threads = ctx->Threads();

View File

@ -19,14 +19,13 @@
#include <vector> #include <vector>
#include "categorical.h" #include "categorical.h"
#include "cuda_context.cuh" // for CUDAContext
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "hist_util.cuh" #include "hist_util.cuh"
#include "hist_util.h" #include "hist_util.h"
#include "math.h" // NOLINT
#include "quantile.h" #include "quantile.h"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
namespace xgboost::common { namespace xgboost::common {
constexpr float SketchContainer::kFactor; constexpr float SketchContainer::kFactor;
@ -109,22 +108,19 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_ro
return std::min(sketch_batch_num_elements, kIntMax); return std::min(sketch_batch_num_elements, kIntMax);
} }
void SortByWeight(dh::device_vector<float>* weights, void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
dh::device_vector<Entry>* sorted_entries) {
// Sort both entries and wegihts. // Sort both entries and wegihts.
dh::XGBDeviceAllocator<char> alloc; dh::XGBDeviceAllocator<char> alloc;
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), CHECK_EQ(weights->size(), sorted_entries->size());
sorted_entries->end(), weights->begin(), thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
detail::EntryCompareOp()); weights->begin(), detail::EntryCompareOp());
// Scan weights // Scan weights
dh::XGBCachingDeviceAllocator<char> caching; dh::XGBCachingDeviceAllocator<char> caching;
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), thrust::inclusive_scan_by_key(
sorted_entries->begin(), sorted_entries->end(), thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
weights->begin(), weights->begin(), weights->begin(),
[=] __device__(const Entry& a, const Entry& b) { [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
return a.index == b.index;
});
} }
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr, void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
@ -200,159 +196,170 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_r
} }
} // namespace detail } // namespace detail
void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page, void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo const& info,
size_t begin, size_t end, SketchContainer *sketch_container, std::size_t begin, std::size_t end,
int num_cuts_per_feature, size_t num_columns) { SketchContainer* sketch_container, // <- output sketch
dh::XGBCachingDeviceAllocator<char> alloc; int num_cuts_per_feature, common::Span<float const> sample_weight) {
dh::device_vector<Entry> sorted_entries; dh::device_vector<Entry> sorted_entries;
if (page.data.DeviceCanRead()) { if (page.data.DeviceCanRead()) {
const auto& device_data = page.data.ConstDevicePointer(); // direct copy if data is already on device
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end); auto const& d_data = page.data.ConstDevicePointer();
sorted_entries = dh::device_vector<Entry>(d_data + begin, d_data + end);
} else { } else {
const auto& host_data = page.data.ConstHostVector(); const auto& h_data = page.data.ConstHostVector();
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin, sorted_entries = dh::device_vector<Entry>(h_data.begin() + begin, h_data.begin() + end);
host_data.begin() + end); }
bst_row_t base_rowid = page.base_rowid;
dh::device_vector<float> entry_weight;
auto cuctx = ctx->CUDACtx();
if (!sample_weight.empty()) {
// Expand sample weight into entry weight.
CHECK_EQ(sample_weight.size(), info.num_row_);
entry_weight.resize(sorted_entries.size());
auto d_temp_weight = dh::ToSpan(entry_weight);
page.offset.SetDevice(ctx->Device());
auto row_ptrs = page.offset.ConstDeviceSpan();
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), entry_weight.size(),
[=] __device__(std::size_t idx) {
std::size_t element_idx = idx + begin;
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
});
detail::SortByWeight(&entry_weight, &sorted_entries);
} else {
thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(),
detail::EntryCompareOp());
} }
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr; HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan; dh::caching_device_vector<size_t> column_sizes_scan;
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN()); data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>( auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(), sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
[] __device__(Entry const &e) -> data::COOTuple { return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
}); });
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, detail::GetColumnSizesScan(ctx->Ordinal(), info.num_col_, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr, IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan); &column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan(); auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) { if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr, auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
detail::RemoveDuplicatedCategories(ctx->Ordinal(), info, d_cuts_ptr, &sorted_entries, p_weight,
&column_sizes_scan); &column_sizes_scan);
} }
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
// add cuts into sketches // Add cuts into sketches
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
d_cuts_ptr, h_cuts_ptr.back()); h_cuts_ptr.back(), dh::ToSpan(entry_weight));
sorted_entries.clear(); sorted_entries.clear();
sorted_entries.shrink_to_fit(); sorted_entries.shrink_to_fit();
CHECK_EQ(sorted_entries.capacity(), 0); CHECK_EQ(sorted_entries.capacity(), 0);
CHECK_NE(cuts_ptr.Size(), 0); CHECK_NE(cuts_ptr.Size(), 0);
} }
void ProcessWeightedBatch(int device, const SparsePage& page, // Unify group weight, Hessian, and sample weight into sample weight.
MetaInfo const& info, size_t begin, size_t end, [[nodiscard]] Span<float const> UnifyWeight(CUDAContext const* cuctx, MetaInfo const& info,
SketchContainer* sketch_container, int num_cuts_per_feature, common::Span<float const> hessian,
size_t num_columns, HostDeviceVector<float>* p_out_weight) {
bool is_ranking, Span<bst_group_t const> d_group_ptr) { if (hessian.empty()) {
auto weights = info.weights_.ConstDeviceSpan(); if (info.IsRanking() && !info.weights_.Empty()) {
common::Span<float const> group_weight = info.weights_.ConstDeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc; dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
const auto& host_data = page.data.ConstHostVector(); auto d_group_ptr = dh::ToSpan(group_ptr);
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin, CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
host_data.begin() + end); auto d_weight = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1)
// Binary search to assign weights to each element
dh::device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = temp_weights.data().get();
page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan();
size_t base_rowid = page.base_rowid;
if (is_ranking) {
CHECK_GE(d_group_ptr.size(), 2)
<< "Must have at least 1 group for ranking.";
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups."; << "Weight size should equal to number of groups.";
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { p_out_weight->Resize(info.num_row_);
size_t element_idx = idx + begin; auto d_weight_out = p_out_weight->DeviceSpan();
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid); thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), d_weight_out.size(),
d_temp_weights[idx] = weights[group_idx]; [=] XGBOOST_DEVICE(std::size_t i) {
auto gidx = dh::SegmentId(d_group_ptr, i);
d_weight_out[i] = d_weight[gidx];
});
return p_out_weight->ConstDeviceSpan();
} else {
return info.weights_.ConstDeviceSpan();
}
}
// sketch with hessian as weight
p_out_weight->Resize(info.num_row_);
auto d_weight_out = p_out_weight->DeviceSpan();
if (!info.weights_.Empty()) {
// merge sample weight with hessian
auto d_weight = info.weights_.ConstDeviceSpan();
if (info.IsRanking()) {
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
CHECK_EQ(hessian.size(), d_weight_out.size());
auto d_group_ptr = dh::ToSpan(group_ptr);
CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(),
[=] XGBOOST_DEVICE(std::size_t i) {
d_weight_out[i] = d_weight[dh::SegmentId(d_group_ptr, i)] * hessian(i);
}); });
} else { } else {
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { CHECK_EQ(hessian.size(), info.num_row_);
size_t element_idx = idx + begin; CHECK_EQ(hessian.size(), d_weight.size());
size_t ridx = dh::SegmentId(row_ptrs, element_idx); CHECK_EQ(hessian.size(), d_weight_out.size());
d_temp_weights[idx] = weights[ridx + base_rowid]; thrust::for_each_n(
}); cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(),
[=] XGBOOST_DEVICE(std::size_t i) { d_weight_out[i] = d_weight[i] * hessian(i); });
} }
detail::SortByWeight(&temp_weights, &sorted_entries); } else {
// copy hessian as weight
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr; CHECK_EQ(d_weight_out.size(), hessian.size());
dh::caching_device_vector<size_t> column_sizes_scan; dh::safe_cuda(cudaMemcpyAsync(d_weight_out.data(), hessian.data(), hessian.size_bytes(),
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN()); cudaMemcpyDefault));
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(),
[] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
&column_sizes_scan);
} }
return d_weight_out;
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
// Extract cuts
sketch_container->Push(dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
} }
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
size_t sketch_batch_num_elements) { Span<float const> hessian,
dmat->Info().feature_types.SetDevice(device); std::size_t sketch_batch_num_elements) {
dmat->Info().feature_types.ConstDevicePointer(); // pull to device early auto const& info = p_fmat->Info();
bool has_weight = !info.weights_.Empty();
info.feature_types.SetDevice(ctx->Device());
HostDeviceVector<float> weight;
weight.SetDevice(ctx->Device());
// Configure batch size based on available memory // Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0; std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
size_t num_cuts_per_feature =
detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(),
dmat->Info().num_row_, num_cuts_per_feature, has_weight);
dmat->Info().num_col_,
dmat->Info().num_nonzero_, CUDAContext const* cuctx = ctx->CUDACtx();
device, num_cuts_per_feature, has_weights);
info.weights_.SetDevice(ctx->Device());
auto d_weight = UnifyWeight(cuctx, info, hessian, &weight);
HistogramCuts cuts; HistogramCuts cuts;
SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_, SketchContainer sketch_container(info.feature_types, max_bin, info.num_col_, info.num_row_,
dmat->Info().num_row_, device); ctx->Ordinal());
CHECK_EQ(has_weight || !hessian.empty(), !d_weight.empty());
for (const auto& page : p_fmat->GetBatches<SparsePage>()) {
std::size_t page_nnz = page.data.Size();
for (auto begin = 0ull; begin < page_nnz; begin += sketch_batch_num_elements) {
std::size_t end =
std::min(page_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
ProcessWeightedBatch(ctx, page, info, begin, end, &sketch_container, num_cuts_per_feature,
d_weight);
}
}
dmat->Info().weights_.SetDevice(device); sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
auto const& info = dmat->Info();
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
if (has_weights) {
bool is_ranking = HostSketchContainer::UseGroup(dmat->Info());
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
ProcessWeightedBatch(
device, batch, dmat->Info(), begin, end,
&sketch_container,
num_cuts_per_feature,
dmat->Info().num_col_,
is_ranking, dh::ToSpan(groups));
} else {
ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container,
num_cuts_per_feature, dmat->Info().num_col_);
}
}
}
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
return cuts; return cuts;
} }
} // namespace xgboost::common } // namespace xgboost::common

View File

@ -11,14 +11,13 @@
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include "../data/device_adapter.cuh" #include "../data/adapter.h" // for IsValidFunctor
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "hist_util.h" #include "hist_util.h"
#include "quantile.cuh" #include "quantile.cuh"
#include "timer.h" #include "xgboost/span.h" // for IterSpan
namespace xgboost { namespace xgboost::common {
namespace common {
namespace cuda { namespace cuda {
/** /**
* copy and paste of the host version, we can't make it a __host__ __device__ function as * copy and paste of the host version, we can't make it a __host__ __device__ function as
@ -246,10 +245,35 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_r
dh::caching_device_vector<size_t>* p_column_sizes_scan); dh::caching_device_vector<size_t>* p_column_sizes_scan);
} // namespace detail } // namespace detail
// Compute sketch on DMatrix. /**
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing. * @brief Compute sketch on DMatrix with GPU and Hessian as weight.
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, *
size_t sketch_batch_num_elements = 0); * @param ctx Runtime context
* @param p_fmat Training feature matrix
* @param max_bin Maximum number of bins for each feature
* @param hessian Hessian vector.
* @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
*
* @return Quantile cuts
*/
HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
Span<float const> hessian,
std::size_t sketch_batch_num_elements = 0);
/**
* @brief Compute sketch on DMatrix with GPU.
*
* @param ctx Runtime context
* @param p_fmat Training feature matrix
* @param max_bin Maximum number of bins for each feature
* @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
*
* @return Quantile cuts
*/
inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
std::size_t sketch_batch_num_elements = 0) {
return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements);
}
template <typename AdapterBatch> template <typename AdapterBatch>
void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
@ -417,7 +441,5 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
} }
} }
} }
} // namespace common } // namespace xgboost::common
} // namespace xgboost
#endif // COMMON_HIST_UTIL_CUH_ #endif // COMMON_HIST_UTIL_CUH_

View File

@ -172,7 +172,7 @@ class HistogramCuts {
* but consumes more memory. * but consumes more memory.
*/ */
HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins, HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins,
bool use_sorted = false, Span<float> const hessian = {}); bool use_sorted = false, Span<float const> hessian = {});
enum BinTypeSize : uint8_t { enum BinTypeSize : uint8_t {
kUint8BinsTypeSize = 1, kUint8BinsTypeSize = 1,

View File

@ -168,6 +168,9 @@ bool HostDeviceVector<T>::DeviceCanWrite() const {
template <typename T> template <typename T>
void HostDeviceVector<T>::SetDevice(int) const {} void HostDeviceVector<T>::SetDevice(int) const {}
template <typename T>
void HostDeviceVector<T>::SetDevice(DeviceOrd) const {}
// explicit instantiations are required, as HostDeviceVector isn't header-only // explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>; template class HostDeviceVector<bst_float>;
template class HostDeviceVector<double>; template class HostDeviceVector<double>;

View File

@ -394,6 +394,11 @@ void HostDeviceVector<T>::SetDevice(int device) const {
impl_->SetDevice(device); impl_->SetDevice(device);
} }
template <typename T>
void HostDeviceVector<T>::SetDevice(DeviceOrd device) const {
impl_->SetDevice(device.ordinal);
}
template <typename T> template <typename T>
void HostDeviceVector<T>::Resize(size_t new_size, T v) { void HostDeviceVector<T>::Resize(size_t new_size, T v) {
impl_->Resize(new_size, v); impl_->Resize(new_size, v);

View File

@ -131,7 +131,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles"); monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts. // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat); row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin); cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
monitor_.Stop("Quantiles"); monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData"); monitor_.Start("InitCompressedData");

View File

@ -21,7 +21,7 @@ GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnM
GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat, GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch, double sparse_thresh, bool sorted_sketch,
common::Span<float> hess) common::Span<float const> hess)
: max_numeric_bins_per_feat{max_bins_per_feat} { : max_numeric_bins_per_feat{max_bins_per_feat} {
CHECK(p_fmat->SingleColBlock()); CHECK(p_fmat->SingleColBlock());
// We use sorted sketching for approx tree method since it's more efficient in // We use sorted sketching for approx tree method since it's more efficient in

View File

@ -160,7 +160,7 @@ class GHistIndexMatrix {
* \brief Constrcutor for SimpleDMatrix. * \brief Constrcutor for SimpleDMatrix.
*/ */
GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat, GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch, common::Span<float> hess = {}); double sparse_thresh, bool sorted_sketch, common::Span<float const> hess = {});
/** /**
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare * \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
* for push batch. * for push batch.

View File

@ -25,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id); cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts; std::unique_ptr<common::HistogramCuts> cuts;
cuts = std::make_unique<common::HistogramCuts>( cuts =
common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0)); std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
this->InitializeSparsePage(ctx); // reset after use. this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this); row_stride = GetRowStride(this);

View File

@ -3,17 +3,22 @@
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <xgboost/base.h> // for bst_bin_t
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <algorithm> #include <algorithm> // for transform
#include <cmath> #include <cmath> // for floor
#include <cstddef> // for size_t
#include <limits> // for numeric_limits
#include <string> // for string, to_string
#include <tuple> // for tuple, make_tuple
#include <vector> // for vector
#include "../../../include/xgboost/logging.h" #include "../../../include/xgboost/logging.h"
#include "../../../src/common/device_helpers.cuh" #include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.cuh" #include "../../../src/common/hist_util.cuh"
#include "../../../src/common/hist_util.h" #include "../../../src/common/hist_util.h"
#include "../../../src/common/math.h"
#include "../../../src/data/device_adapter.cuh" #include "../../../src/data/device_adapter.cuh"
#include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/simple_dmatrix.h"
#include "../data/test_array_interface.h" #include "../data/test_array_interface.h"
@ -21,8 +26,7 @@
#include "../helpers.h" #include "../helpers.h"
#include "test_hist_util.h" #include "test_hist_util.h"
namespace xgboost { namespace xgboost::common {
namespace common {
template <typename AdapterT> template <typename AdapterT>
HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) { HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) {
@ -32,16 +36,17 @@ HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, f
} }
TEST(HistUtil, DeviceSketch) { TEST(HistUtil, DeviceSketch) {
auto ctx = MakeCUDACtx(0);
int num_columns = 1; int num_columns = 1;
int num_bins = 4; int num_bins = 4;
std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f}; std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f};
int num_rows = x.size(); int num_rows = x.size();
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
Context ctx; Context cpu_ctx;
HistogramCuts host_cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins); HistogramCuts host_cuts = SketchOnDMatrix(&cpu_ctx, dmat.get(), num_bins);
EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
@ -65,6 +70,7 @@ TEST(HistUtil, SketchBatchNumElements) {
} }
TEST(HistUtil, DeviceSketchMemory) { TEST(HistUtil, DeviceSketchMemory) {
auto ctx = MakeCUDACtx(0);
int num_columns = 100; int num_columns = 100;
int num_rows = 1000; int num_rows = 1000;
int num_bins = 256; int num_bins = 256;
@ -73,7 +79,7 @@ TEST(HistUtil, DeviceSketchMemory) {
dh::GlobalMemoryLogger().Clear(); dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}}); ConsoleLogger::Configure({{"verbosity", "3"}});
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
size_t bytes_required = detail::RequiredMemory( size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false); num_rows, num_columns, num_rows * num_columns, num_bins, false);
@ -83,6 +89,7 @@ TEST(HistUtil, DeviceSketchMemory) {
} }
TEST(HistUtil, DeviceSketchWeightsMemory) { TEST(HistUtil, DeviceSketchWeightsMemory) {
auto ctx = MakeCUDACtx(0);
int num_columns = 100; int num_columns = 100;
int num_rows = 1000; int num_rows = 1000;
int num_bins = 256; int num_bins = 256;
@ -92,7 +99,7 @@ TEST(HistUtil, DeviceSketchWeightsMemory) {
dh::GlobalMemoryLogger().Clear(); dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}}); ConsoleLogger::Configure({{"verbosity", "3"}});
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
ConsoleLogger::Configure({{"verbosity", "0"}}); ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = detail::RequiredMemory( size_t bytes_required = detail::RequiredMemory(
@ -102,42 +109,43 @@ TEST(HistUtil, DeviceSketchWeightsMemory) {
} }
TEST(HistUtil, DeviceSketchDeterminism) { TEST(HistUtil, DeviceSketchDeterminism) {
auto ctx = MakeCUDACtx(0);
int num_rows = 500; int num_rows = 500;
int num_columns = 5; int num_columns = 5;
int num_bins = 256; int num_bins = 256;
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto reference_sketch = DeviceSketch(0, dmat.get(), num_bins); auto reference_sketch = DeviceSketch(&ctx, dmat.get(), num_bins);
size_t constexpr kRounds{ 100 }; size_t constexpr kRounds{ 100 };
for (size_t r = 0; r < kRounds; ++r) { for (size_t r = 0; r < kRounds; ++r) {
auto new_sketch = DeviceSketch(0, dmat.get(), num_bins); auto new_sketch = DeviceSketch(&ctx, dmat.get(), num_bins);
ASSERT_EQ(reference_sketch.Values(), new_sketch.Values()); ASSERT_EQ(reference_sketch.Values(), new_sketch.Values());
ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues()); ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues());
} }
} }
TEST(HistUtil, DeviceSketchCategoricalAsNumeric) { TEST(HistUtil, DeviceSketchCategoricalAsNumeric) {
int categorical_sizes[] = {2, 6, 8, 12}; auto ctx = MakeCUDACtx(0);
auto categorical_sizes = {2, 6, 8, 12};
int num_bins = 256; int num_bins = 256;
int sizes[] = {25, 100, 1000}; auto sizes = {25, 100, 1000};
for (auto n : sizes) { for (auto n : sizes) {
for (auto num_categories : categorical_sizes) { for (auto num_categories : categorical_sizes) {
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
auto dmat = GetDMatrixFromData(x, n, 1); auto dmat = GetDMatrixFromData(x, n, 1);
auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
} }
TEST(HistUtil, DeviceSketchCategoricalFeatures) { TEST(HistUtil, DeviceSketchCategoricalFeatures) {
TestCategoricalSketch(1000, 256, 32, false, auto ctx = MakeCUDACtx(0);
[](DMatrix *p_fmat, int32_t num_bins) { TestCategoricalSketch(1000, 256, 32, false, [ctx](DMatrix* p_fmat, int32_t num_bins) {
return DeviceSketch(0, p_fmat, num_bins); return DeviceSketch(&ctx, p_fmat, num_bins);
}); });
TestCategoricalSketch(1000, 256, 32, true, TestCategoricalSketch(1000, 256, 32, true, [ctx](DMatrix* p_fmat, int32_t num_bins) {
[](DMatrix *p_fmat, int32_t num_bins) { return DeviceSketch(&ctx, p_fmat, num_bins);
return DeviceSketch(0, p_fmat, num_bins);
}); });
} }
@ -162,7 +170,8 @@ void TestMixedSketch() {
m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
m->Info().feature_types.HostVector().push_back(FeatureType::kNumerical); m->Info().feature_types.HostVector().push_back(FeatureType::kNumerical);
auto cuts = DeviceSketch(0, m.get(), n_bins); auto ctx = MakeCUDACtx(0);
auto cuts = DeviceSketch(&ctx, m.get(), n_bins);
ASSERT_EQ(cuts.Values().size(), n_bins + n_categories); ASSERT_EQ(cuts.Values().size(), n_bins + n_categories);
} }
@ -234,37 +243,40 @@ TEST(HistUtil, RemoveDuplicatedCategories) {
} }
TEST(HistUtil, DeviceSketchMultipleColumns) { TEST(HistUtil, DeviceSketchMultipleColumns) {
int bin_sizes[] = {2, 16, 256, 512}; auto ctx = MakeCUDACtx(0);
int sizes[] = {100, 1000, 1500}; auto bin_sizes = {2, 16, 256, 512};
auto sizes = {100, 1000, 1500};
int num_columns = 5; int num_columns = 5;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
} }
TEST(HistUtil, DeviceSketchMultipleColumnsWeights) { TEST(HistUtil, DeviceSketchMultipleColumnsWeights) {
int bin_sizes[] = {2, 16, 256, 512}; auto ctx = MakeCUDACtx(0);
int sizes[] = {100, 1000, 1500}; auto bin_sizes = {2, 16, 256, 512};
auto sizes = {100, 1000, 1500};
int num_columns = 5; int num_columns = 5;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows); dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
} }
TEST(HistUitl, DeviceSketchWeights) { TEST(HistUitl, DeviceSketchWeights) {
int bin_sizes[] = {2, 16, 256, 512}; auto ctx = MakeCUDACtx(0);
int sizes[] = {100, 1000, 1500}; auto bin_sizes = {2, 16, 256, 512};
auto sizes = {100, 1000, 1500};
int num_columns = 5; int num_columns = 5;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
@ -274,8 +286,8 @@ TEST(HistUitl, DeviceSketchWeights) {
h_weights.resize(num_rows); h_weights.resize(num_rows);
std::fill(h_weights.begin(), h_weights.end(), 1.0f); std::fill(h_weights.begin(), h_weights.end(), 1.0f);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins); auto wcuts = DeviceSketch(&ctx, weighted_dmat.get(), num_bins);
ASSERT_EQ(cuts.MinValues(), wcuts.MinValues()); ASSERT_EQ(cuts.MinValues(), wcuts.MinValues());
ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs()); ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs());
ASSERT_EQ(cuts.Values(), wcuts.Values()); ASSERT_EQ(cuts.Values(), wcuts.Values());
@ -286,14 +298,15 @@ TEST(HistUitl, DeviceSketchWeights) {
} }
TEST(HistUtil, DeviceSketchBatches) { TEST(HistUtil, DeviceSketchBatches) {
auto ctx = MakeCUDACtx(0);
int num_bins = 256; int num_bins = 256;
int num_rows = 5000; int num_rows = 5000;
int batch_sizes[] = {0, 100, 1500, 6000}; auto batch_sizes = {0, 100, 1500, 6000};
int num_columns = 5; int num_columns = 5;
for (auto batch_size : batch_sizes) { for (auto batch_size : batch_sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins, batch_size);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
@ -301,8 +314,8 @@ TEST(HistUtil, DeviceSketchBatches) {
size_t batches = 16; size_t batches = 16;
auto x = GenerateRandom(num_rows * batches, num_columns); auto x = GenerateRandom(num_rows * batches, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns); auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns);
auto cuts_with_batches = DeviceSketch(0, dmat.get(), num_bins, num_rows); auto cuts_with_batches = DeviceSketch(&ctx, dmat.get(), num_bins, num_rows);
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins, 0);
auto const& cut_values_batched = cuts_with_batches.Values(); auto const& cut_values_batched = cuts_with_batches.Values();
auto const& cut_values = cuts.Values(); auto const& cut_values = cuts.Values();
@ -313,15 +326,16 @@ TEST(HistUtil, DeviceSketchBatches) {
} }
TEST(HistUtil, DeviceSketchMultipleColumnsExternal) { TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
int bin_sizes[] = {2, 16, 256, 512}; auto ctx = MakeCUDACtx(0);
int sizes[] = {100, 1000, 1500}; auto bin_sizes = {2, 16, 256, 512};
auto sizes = {100, 1000, 1500};
int num_columns =5; int num_columns =5;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
dmlc::TemporaryDirectory temp; dmlc::TemporaryDirectory temp;
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp); auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
@ -329,8 +343,9 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
// See https://github.com/dmlc/xgboost/issues/5866. // See https://github.com/dmlc/xgboost/issues/5866.
TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) { TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
int bin_sizes[] = {2, 16, 256, 512}; auto ctx = MakeCUDACtx(0);
int sizes[] = {100, 1000, 1500}; auto bin_sizes = {2, 16, 256, 512};
auto sizes = {100, 1000, 1500};
int num_columns = 5; int num_columns = 5;
dmlc::TemporaryDirectory temp; dmlc::TemporaryDirectory temp;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
@ -338,7 +353,7 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp); auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp);
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows); dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
@ -504,9 +519,9 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
} }
TEST(HistUtil, AdapterDeviceSketchCategorical) { TEST(HistUtil, AdapterDeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12}; auto categorical_sizes = {2, 6, 8, 12};
int num_bins = 256; int num_bins = 256;
int sizes[] = {25, 100, 1000}; auto sizes = {25, 100, 1000};
for (auto n : sizes) { for (auto n : sizes) {
for (auto num_categories : categorical_sizes) { for (auto num_categories : categorical_sizes) {
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
@ -521,8 +536,8 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) {
} }
TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
int bin_sizes[] = {2, 16, 256, 512}; auto bin_sizes = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500}; auto sizes = {100, 1000, 1500};
int num_columns = 5; int num_columns = 5;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
@ -538,7 +553,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
TEST(HistUtil, AdapterDeviceSketchBatches) { TEST(HistUtil, AdapterDeviceSketchBatches) {
int num_bins = 256; int num_bins = 256;
int num_rows = 5000; int num_rows = 5000;
int batch_sizes[] = {0, 100, 1500, 6000}; auto batch_sizes = {0, 100, 1500, 6000};
int num_columns = 5; int num_columns = 5;
for (auto batch_size : batch_sizes) { for (auto batch_size : batch_sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
@ -619,14 +634,15 @@ TEST(HistUtil, GetColumnSize) {
// Check sketching from adapter or DMatrix results in the same answer // Check sketching from adapter or DMatrix results in the same answer
// Consistency here is useful for testing and user experience // Consistency here is useful for testing and user experience
TEST(HistUtil, SketchingEquivalent) { TEST(HistUtil, SketchingEquivalent) {
int bin_sizes[] = {2, 16, 256, 512}; auto ctx = MakeCUDACtx(0);
int sizes[] = {100, 1000, 1500}; auto bin_sizes = {2, 16, 256, 512};
auto sizes = {100, 1000, 1500};
int num_columns = 5; int num_columns = 5;
for (auto num_rows : sizes) { for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins); auto dmat_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
auto x_device = thrust::device_vector<float>(x); auto x_device = thrust::device_vector<float>(x);
auto adapter = AdapterFromData(x_device, num_rows, num_columns); auto adapter = AdapterFromData(x_device, num_rows, num_columns);
common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest( common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest(
@ -641,21 +657,25 @@ TEST(HistUtil, SketchingEquivalent) {
} }
TEST(HistUtil, DeviceSketchFromGroupWeights) { TEST(HistUtil, DeviceSketchFromGroupWeights) {
auto ctx = MakeCUDACtx(0);
size_t constexpr kRows = 3000, kCols = 200, kBins = 256; size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
size_t constexpr kGroups = 10; size_t constexpr kGroups = 10;
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
// sketch with group weight
auto& h_weights = m->Info().weights_.HostVector(); auto& h_weights = m->Info().weights_.HostVector();
h_weights.resize(kRows); h_weights.resize(kGroups);
std::fill(h_weights.begin(), h_weights.end(), 1.0f); std::fill(h_weights.begin(), h_weights.end(), 1.0f);
std::vector<bst_group_t> groups(kGroups); std::vector<bst_group_t> groups(kGroups);
for (size_t i = 0; i < kGroups; ++i) { for (size_t i = 0; i < kGroups; ++i) {
groups[i] = kRows / kGroups; groups[i] = kRows / kGroups;
} }
m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0); HistogramCuts weighted_cuts = DeviceSketch(&ctx, m.get(), kBins, 0);
// sketch with no weight
h_weights.clear(); h_weights.clear();
HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0); HistogramCuts cuts = DeviceSketch(&ctx, m.get(), kBins, 0);
ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size()); ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size()); ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
@ -723,9 +743,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
ValidateCuts(cuts, dmat.get(), kBins); ValidateCuts(cuts, dmat.get(), kBins);
auto cuda_ctx = MakeCUDACtx(0);
if (with_group) { if (with_group) {
dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight
HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0); HistogramCuts non_weighted = DeviceSketch(&cuda_ctx, dmat.get(), kBins, 0);
for (size_t i = 0; i < cuts.Values().size(); ++i) { for (size_t i = 0; i < cuts.Values().size(); ++i) {
ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]); ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
} }
@ -760,5 +781,156 @@ TEST(HistUtil, AdapterSketchFromWeights) {
TestAdapterSketchFromWeights(false); TestAdapterSketchFromWeights(false);
TestAdapterSketchFromWeights(true); TestAdapterSketchFromWeights(true);
} }
} // namespace common
} // namespace xgboost namespace {
class DeviceSketchWithHessianTest
: public ::testing::TestWithParam<std::tuple<bool, bst_row_t, bst_bin_t>> {
bst_feature_t n_features_ = 5;
bst_group_t n_groups_{3};
auto GenerateHessian(Context const* ctx, bst_row_t n_samples) const {
HostDeviceVector<float> hessian;
auto& h_hess = hessian.HostVector();
h_hess = GenerateRandomWeights(n_samples);
std::mt19937 rng(0);
std::shuffle(h_hess.begin(), h_hess.end(), rng);
hessian.SetDevice(ctx->Device());
return hessian;
}
void CheckReg(Context const* ctx, std::shared_ptr<DMatrix> p_fmat, bst_bin_t n_bins,
HostDeviceVector<float> const& hessian, std::vector<float> const& w,
std::size_t n_elements) const {
auto const& h_hess = hessian.ConstHostVector();
{
auto& h_weight = p_fmat->Info().weights_.HostVector();
h_weight = w;
}
HistogramCuts cuts_hess =
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
ValidateCuts(cuts_hess, p_fmat.get(), n_bins);
// merge hessian
{
auto& h_weight = p_fmat->Info().weights_.HostVector();
ASSERT_EQ(h_weight.size(), h_hess.size());
for (std::size_t i = 0; i < h_weight.size(); ++i) {
h_weight[i] = w[i] * h_hess[i];
}
}
HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements);
ValidateCuts(cuts_wh, p_fmat.get(), n_bins);
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
for (std::size_t i = 0; i < cuts_hess.Values().size(); ++i) {
ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps);
}
p_fmat->Info().weights_.HostVector() = w;
}
protected:
Context ctx_ = MakeCUDACtx(0);
void TestLTR(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins,
std::size_t n_elements) const {
auto x = GenerateRandom(n_samples, n_features_);
std::vector<bst_group_t> gptr;
gptr.resize(n_groups_ + 1, 0);
gptr[1] = n_samples / n_groups_;
gptr[2] = n_samples / n_groups_ + gptr[1];
gptr.back() = n_samples;
auto hessian = this->GenerateHessian(ctx, n_samples);
auto const& h_hess = hessian.ConstHostVector();
auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_);
p_fmat->Info().group_ptr_ = gptr;
// test with constant group weight
std::vector<float> w(n_groups_, 1.0f);
p_fmat->Info().weights_.HostVector() = w;
HistogramCuts cuts_hess =
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
// make validation easier by converting it into sample weight.
p_fmat->Info().weights_.HostVector() = h_hess;
p_fmat->Info().group_ptr_.clear();
ValidateCuts(cuts_hess, p_fmat.get(), n_bins);
// restore ltr properties
p_fmat->Info().weights_.HostVector() = w;
p_fmat->Info().group_ptr_ = gptr;
// test with random group weight
w = GenerateRandomWeights(n_groups_);
p_fmat->Info().weights_.HostVector() = w;
cuts_hess =
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
// make validation easier by converting it into sample weight.
p_fmat->Info().weights_.HostVector() = h_hess;
p_fmat->Info().group_ptr_.clear();
ValidateCuts(cuts_hess, p_fmat.get(), n_bins);
// merge hessian with sample weight
p_fmat->Info().weights_.Resize(n_samples);
p_fmat->Info().group_ptr_.clear();
for (std::size_t i = 0; i < h_hess.size(); ++i) {
auto gidx = dh::SegmentId(Span{gptr.data(), gptr.size()}, i);
p_fmat->Info().weights_.HostVector()[i] = w[gidx] * h_hess[i];
}
auto cuts = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements);
ValidateCuts(cuts, p_fmat.get(), n_bins);
ASSERT_EQ(cuts.Values().size(), cuts_hess.Values().size());
for (std::size_t i = 0; i < cuts.Values().size(); ++i) {
EXPECT_NEAR(cuts.Values()[i], cuts_hess.Values()[i], 1e-4f);
}
}
void TestRegression(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins,
std::size_t n_elements) const {
auto x = GenerateRandom(n_samples, n_features_);
auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_);
std::vector<float> w = GenerateRandomWeights(n_samples);
auto hessian = this->GenerateHessian(ctx, n_samples);
this->CheckReg(ctx, p_fmat, n_bins, hessian, w, n_elements);
}
};
auto MakeParamsForTest() {
std::vector<bst_row_t> sizes = {1, 2, 256, 512, 1000, 1500};
std::vector<bst_bin_t> bin_sizes = {2, 16, 256, 512};
std::vector<std::tuple<bool, bst_row_t, bst_bin_t>> configs;
for (auto n_samples : sizes) {
for (auto n_bins : bin_sizes) {
configs.emplace_back(true, n_samples, n_bins);
configs.emplace_back(false, n_samples, n_bins);
}
}
return configs;
}
} // namespace
TEST_P(DeviceSketchWithHessianTest, DeviceSketchWithHessian) {
auto param = GetParam();
auto n_samples = std::get<1>(param);
auto n_bins = std::get<2>(param);
if (std::get<0>(param)) {
this->TestLTR(&ctx_, n_samples, n_bins, 0);
this->TestLTR(&ctx_, n_samples, n_bins, 512);
} else {
this->TestRegression(&ctx_, n_samples, n_bins, 0);
this->TestRegression(&ctx_, n_samples, n_bins, 512);
}
}
INSTANTIATE_TEST_SUITE_P(
HistUtil, DeviceSketchWithHessianTest, ::testing::ValuesIn(MakeParamsForTest()),
[](::testing::TestParamInfo<DeviceSketchWithHessianTest::ParamType> const& info) {
auto task = std::get<0>(info.param) ? "ltr" : "reg";
auto n_samples = std::to_string(std::get<1>(info.param));
auto n_bins = std::to_string(std::get<2>(info.param));
return std::string{task} + "_" + n_samples + "_" + n_bins;
});
} // namespace xgboost::common

View File

@ -1,9 +1,14 @@
/**
* Copyright 2020-2023, XGBoost contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "test_quantile.h"
#include "../helpers.h"
#include "../../../src/collective/communicator-inl.cuh" #include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/common/hist_util.cuh" #include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh" #include "../../../src/common/quantile.cuh"
#include "../../../src/data/device_adapter.cuh" // CupyAdapter
#include "../helpers.h"
#include "test_quantile.h"
namespace xgboost { namespace xgboost {
namespace { namespace {
@ -437,13 +442,13 @@ void TestColumnSplitBasic() {
}()}; }()};
// Generate cuts for distributed environment. // Generate cuts for distributed environment.
auto const device = rank; auto ctx = MakeCUDACtx(rank);
HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins); HistogramCuts distributed_cuts = common::DeviceSketch(&ctx, m.get(), kBins);
// Generate cuts for single node environment // Generate cuts for single node environment
collective::Finalize(); collective::Finalize();
CHECK_EQ(collective::GetWorldSize(), 1); CHECK_EQ(collective::GetWorldSize(), 1);
HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins); HistogramCuts single_node_cuts = common::DeviceSketch(&ctx, m.get(), kBins);
auto const& sptrs = single_node_cuts.Ptrs(); auto const& sptrs = single_node_cuts.Ptrs();
auto const& dptrs = distributed_cuts.Ptrs(); auto const& dptrs = distributed_cuts.Ptrs();