Implement sketching with Hessian on GPU. (#9399)

- Prepare for implementing approx on GPU.
- Unify the code path between weighted and uniform sketching on DMatrix.
This commit is contained in:
Jiaming Yuan
2023-07-24 15:43:03 +08:00
committed by GitHub
parent 851cba931e
commit a196443a07
14 changed files with 446 additions and 230 deletions

View File

@@ -12,8 +12,8 @@
#include "../data/gradient_index.h" // for GHistIndexMatrix
#include "quantile.h"
#include "xgboost/base.h"
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // SparsePage, SortedCSCPage
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for SparsePage, SortedCSCPage
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h>
@@ -30,7 +30,7 @@ HistogramCuts::HistogramCuts() {
}
HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted,
Span<float> const hessian) {
Span<float const> hessian) {
HistogramCuts out;
auto const &info = m->Info();
auto n_threads = ctx->Threads();

View File

@@ -19,14 +19,13 @@
#include <vector>
#include "categorical.h"
#include "cuda_context.cuh" // for CUDAContext
#include "device_helpers.cuh"
#include "hist_util.cuh"
#include "hist_util.h"
#include "math.h" // NOLINT
#include "quantile.h"
#include "xgboost/host_device_vector.h"
namespace xgboost::common {
constexpr float SketchContainer::kFactor;
@@ -109,22 +108,19 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_ro
return std::min(sketch_batch_num_elements, kIntMax);
}
void SortByWeight(dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries) {
void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
// Sort both entries and wegihts.
dh::XGBDeviceAllocator<char> alloc;
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
sorted_entries->end(), weights->begin(),
detail::EntryCompareOp());
CHECK_EQ(weights->size(), sorted_entries->size());
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
weights->begin(), detail::EntryCompareOp());
// Scan weights
dh::XGBCachingDeviceAllocator<char> caching;
thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
sorted_entries->begin(), sorted_entries->end(),
weights->begin(), weights->begin(),
[=] __device__(const Entry& a, const Entry& b) {
return a.index == b.index;
});
thrust::inclusive_scan_by_key(
thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
weights->begin(),
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
}
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
@@ -200,159 +196,170 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_r
}
} // namespace detail
void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
size_t begin, size_t end, SketchContainer *sketch_container,
int num_cuts_per_feature, size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo const& info,
std::size_t begin, std::size_t end,
SketchContainer* sketch_container, // <- output sketch
int num_cuts_per_feature, common::Span<float const> sample_weight) {
dh::device_vector<Entry> sorted_entries;
if (page.data.DeviceCanRead()) {
const auto& device_data = page.data.ConstDevicePointer();
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end);
// direct copy if data is already on device
auto const& d_data = page.data.ConstDevicePointer();
sorted_entries = dh::device_vector<Entry>(d_data + begin, d_data + end);
} else {
const auto& host_data = page.data.ConstHostVector();
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin,
host_data.begin() + end);
const auto& h_data = page.data.ConstHostVector();
sorted_entries = dh::device_vector<Entry>(h_data.begin() + begin, h_data.begin() + end);
}
bst_row_t base_rowid = page.base_rowid;
dh::device_vector<float> entry_weight;
auto cuctx = ctx->CUDACtx();
if (!sample_weight.empty()) {
// Expand sample weight into entry weight.
CHECK_EQ(sample_weight.size(), info.num_row_);
entry_weight.resize(sorted_entries.size());
auto d_temp_weight = dh::ToSpan(entry_weight);
page.offset.SetDevice(ctx->Device());
auto row_ptrs = page.offset.ConstDeviceSpan();
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), entry_weight.size(),
[=] __device__(std::size_t idx) {
std::size_t element_idx = idx + begin;
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
});
detail::SortByWeight(&entry_weight, &sorted_entries);
} else {
thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(),
detail::EntryCompareOp());
}
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(),
[] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
detail::GetColumnSizesScan(ctx->Ordinal(), info.num_col_, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
detail::RemoveDuplicatedCategories(ctx->Ordinal(), info, d_cuts_ptr, &sorted_entries, p_weight,
&column_sizes_scan);
}
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
// add cuts into sketches
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
d_cuts_ptr, h_cuts_ptr.back());
// Add cuts into sketches
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
CHECK_EQ(sorted_entries.capacity(), 0);
CHECK_NE(cuts_ptr.Size(), 0);
}
void ProcessWeightedBatch(int device, const SparsePage& page,
MetaInfo const& info, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts_per_feature,
size_t num_columns,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
auto weights = info.weights_.ConstDeviceSpan();
// Unify group weight, Hessian, and sample weight into sample weight.
[[nodiscard]] Span<float const> UnifyWeight(CUDAContext const* cuctx, MetaInfo const& info,
common::Span<float const> hessian,
HostDeviceVector<float>* p_out_weight) {
if (hessian.empty()) {
if (info.IsRanking() && !info.weights_.Empty()) {
common::Span<float const> group_weight = info.weights_.ConstDeviceSpan();
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
auto d_group_ptr = dh::ToSpan(group_ptr);
CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
auto d_weight = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
p_out_weight->Resize(info.num_row_);
auto d_weight_out = p_out_weight->DeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
// Binary search to assign weights to each element
dh::device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = temp_weights.data().get();
page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan();
size_t base_rowid = page.base_rowid;
if (is_ranking) {
CHECK_GE(d_group_ptr.size(), 2)
<< "Must have at least 1 group for ranking.";
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid);
d_temp_weights[idx] = weights[group_idx];
});
} else {
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}
detail::SortByWeight(&temp_weights, &sorted_entries);
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(),
[] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
&column_sizes_scan);
}
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
// Extract cuts
sketch_container->Push(dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
}
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) {
dmat->Info().feature_types.SetDevice(device);
dmat->Info().feature_types.ConstDevicePointer(); // pull to device early
// Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts_per_feature =
detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
dmat->Info().num_row_,
dmat->Info().num_col_,
dmat->Info().num_nonzero_,
device, num_cuts_per_feature, has_weights);
HistogramCuts cuts;
SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_,
dmat->Info().num_row_, device);
dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
auto const& info = dmat->Info();
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
if (has_weights) {
bool is_ranking = HostSketchContainer::UseGroup(dmat->Info());
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
ProcessWeightedBatch(
device, batch, dmat->Info(), begin, end,
&sketch_container,
num_cuts_per_feature,
dmat->Info().num_col_,
is_ranking, dh::ToSpan(groups));
} else {
ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container,
num_cuts_per_feature, dmat->Info().num_col_);
}
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), d_weight_out.size(),
[=] XGBOOST_DEVICE(std::size_t i) {
auto gidx = dh::SegmentId(d_group_ptr, i);
d_weight_out[i] = d_weight[gidx];
});
return p_out_weight->ConstDeviceSpan();
} else {
return info.weights_.ConstDeviceSpan();
}
}
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
// sketch with hessian as weight
p_out_weight->Resize(info.num_row_);
auto d_weight_out = p_out_weight->DeviceSpan();
if (!info.weights_.Empty()) {
// merge sample weight with hessian
auto d_weight = info.weights_.ConstDeviceSpan();
if (info.IsRanking()) {
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
CHECK_EQ(hessian.size(), d_weight_out.size());
auto d_group_ptr = dh::ToSpan(group_ptr);
CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(),
[=] XGBOOST_DEVICE(std::size_t i) {
d_weight_out[i] = d_weight[dh::SegmentId(d_group_ptr, i)] * hessian(i);
});
} else {
CHECK_EQ(hessian.size(), info.num_row_);
CHECK_EQ(hessian.size(), d_weight.size());
CHECK_EQ(hessian.size(), d_weight_out.size());
thrust::for_each_n(
cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(),
[=] XGBOOST_DEVICE(std::size_t i) { d_weight_out[i] = d_weight[i] * hessian(i); });
}
} else {
// copy hessian as weight
CHECK_EQ(d_weight_out.size(), hessian.size());
dh::safe_cuda(cudaMemcpyAsync(d_weight_out.data(), hessian.data(), hessian.size_bytes(),
cudaMemcpyDefault));
}
return d_weight_out;
}
HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
Span<float const> hessian,
std::size_t sketch_batch_num_elements) {
auto const& info = p_fmat->Info();
bool has_weight = !info.weights_.Empty();
info.feature_types.SetDevice(ctx->Device());
HostDeviceVector<float> weight;
weight.SetDevice(ctx->Device());
// Configure batch size based on available memory
std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(),
num_cuts_per_feature, has_weight);
CUDAContext const* cuctx = ctx->CUDACtx();
info.weights_.SetDevice(ctx->Device());
auto d_weight = UnifyWeight(cuctx, info, hessian, &weight);
HistogramCuts cuts;
SketchContainer sketch_container(info.feature_types, max_bin, info.num_col_, info.num_row_,
ctx->Ordinal());
CHECK_EQ(has_weight || !hessian.empty(), !d_weight.empty());
for (const auto& page : p_fmat->GetBatches<SparsePage>()) {
std::size_t page_nnz = page.data.Size();
for (auto begin = 0ull; begin < page_nnz; begin += sketch_batch_num_elements) {
std::size_t end =
std::min(page_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
ProcessWeightedBatch(ctx, page, info, begin, end, &sketch_container, num_cuts_per_feature,
d_weight);
}
}
sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
return cuts;
}
} // namespace xgboost::common

View File

@@ -11,14 +11,13 @@
#include <cstddef> // for size_t
#include "../data/device_adapter.cuh"
#include "../data/adapter.h" // for IsValidFunctor
#include "device_helpers.cuh"
#include "hist_util.h"
#include "quantile.cuh"
#include "timer.h"
#include "xgboost/span.h" // for IterSpan
namespace xgboost {
namespace common {
namespace xgboost::common {
namespace cuda {
/**
* copy and paste of the host version, we can't make it a __host__ __device__ function as
@@ -246,10 +245,35 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_r
dh::caching_device_vector<size_t>* p_column_sizes_scan);
} // namespace detail
// Compute sketch on DMatrix.
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements = 0);
/**
* @brief Compute sketch on DMatrix with GPU and Hessian as weight.
*
* @param ctx Runtime context
* @param p_fmat Training feature matrix
* @param max_bin Maximum number of bins for each feature
* @param hessian Hessian vector.
* @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
*
* @return Quantile cuts
*/
HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
Span<float const> hessian,
std::size_t sketch_batch_num_elements = 0);
/**
* @brief Compute sketch on DMatrix with GPU.
*
* @param ctx Runtime context
* @param p_fmat Training feature matrix
* @param max_bin Maximum number of bins for each feature
* @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
*
* @return Quantile cuts
*/
inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
std::size_t sketch_batch_num_elements = 0) {
return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements);
}
template <typename AdapterBatch>
void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
@@ -417,7 +441,5 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
}
}
}
} // namespace common
} // namespace xgboost
} // namespace xgboost::common
#endif // COMMON_HIST_UTIL_CUH_

View File

@@ -172,7 +172,7 @@ class HistogramCuts {
* but consumes more memory.
*/
HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins,
bool use_sorted = false, Span<float> const hessian = {});
bool use_sorted = false, Span<float const> hessian = {});
enum BinTypeSize : uint8_t {
kUint8BinsTypeSize = 1,

View File

@@ -168,6 +168,9 @@ bool HostDeviceVector<T>::DeviceCanWrite() const {
template <typename T>
void HostDeviceVector<T>::SetDevice(int) const {}
template <typename T>
void HostDeviceVector<T>::SetDevice(DeviceOrd) const {}
// explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<double>;

View File

@@ -394,6 +394,11 @@ void HostDeviceVector<T>::SetDevice(int device) const {
impl_->SetDevice(device);
}
template <typename T>
void HostDeviceVector<T>::SetDevice(DeviceOrd device) const {
impl_->SetDevice(device.ordinal);
}
template <typename T>
void HostDeviceVector<T>::Resize(size_t new_size, T v) {
impl_->Resize(new_size, v);

View File

@@ -131,7 +131,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin);
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");

View File

@@ -21,7 +21,7 @@ GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnM
GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch,
common::Span<float> hess)
common::Span<float const> hess)
: max_numeric_bins_per_feat{max_bins_per_feat} {
CHECK(p_fmat->SingleColBlock());
// We use sorted sketching for approx tree method since it's more efficient in

View File

@@ -160,7 +160,7 @@ class GHistIndexMatrix {
* \brief Constrcutor for SimpleDMatrix.
*/
GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch, common::Span<float> hess = {});
double sparse_thresh, bool sorted_sketch, common::Span<float const> hess = {});
/**
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
* for push batch.

View File

@@ -25,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts;
cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0));
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this);