Implement sketching with Hessian on GPU. (#9399)
- Prepare for implementing approx on GPU. - Unify the code path between weighted and uniform sketching on DMatrix.
This commit is contained in:
parent
851cba931e
commit
a196443a07
@ -185,10 +185,10 @@ class MetaInfo {
|
||||
return data_split_mode == DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
/*! \brief Whether the data is split column-wise. */
|
||||
bool IsColumnSplit() const {
|
||||
return data_split_mode == DataSplitMode::kCol;
|
||||
}
|
||||
/** @brief Whether the data is split column-wise. */
|
||||
bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
|
||||
/** @brief Whether this is a learning to rank data. */
|
||||
bool IsRanking() const { return !group_ptr_.empty(); }
|
||||
|
||||
/*!
|
||||
* \brief A convenient method to check if we are doing vertical federated learning, which requires
|
||||
@ -249,7 +249,7 @@ struct BatchParam {
|
||||
/**
|
||||
* \brief Hessian, used for sketching with future approx implementation.
|
||||
*/
|
||||
common::Span<float> hess;
|
||||
common::Span<float const> hess;
|
||||
/**
|
||||
* \brief Whether should we force DMatrix to regenerate the batch. Only used for
|
||||
* GHistIndex.
|
||||
@ -279,7 +279,7 @@ struct BatchParam {
|
||||
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
|
||||
* span is changed, so caller should keep the span for each iteration.
|
||||
*/
|
||||
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
|
||||
BatchParam(bst_bin_t max_bin, common::Span<float const> hessian, bool regenerate)
|
||||
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
||||
|
||||
[[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {
|
||||
|
||||
@ -49,11 +49,12 @@
|
||||
#ifndef XGBOOST_HOST_DEVICE_VECTOR_H_
|
||||
#define XGBOOST_HOST_DEVICE_VECTOR_H_
|
||||
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <type_traits>
|
||||
#include <xgboost/context.h> // for DeviceOrd
|
||||
#include <xgboost/span.h> // for Span
|
||||
|
||||
#include "span.h"
|
||||
#include <initializer_list>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -133,6 +134,7 @@ class HostDeviceVector {
|
||||
GPUAccess DeviceAccess() const;
|
||||
|
||||
void SetDevice(int device) const;
|
||||
void SetDevice(DeviceOrd device) const;
|
||||
|
||||
void Resize(size_t new_size, T v = T());
|
||||
|
||||
|
||||
@ -12,8 +12,8 @@
|
||||
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "quantile.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // SparsePage, SortedCSCPage
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for SparsePage, SortedCSCPage
|
||||
|
||||
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
||||
#include <xmmintrin.h>
|
||||
@ -30,7 +30,7 @@ HistogramCuts::HistogramCuts() {
|
||||
}
|
||||
|
||||
HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted,
|
||||
Span<float> const hessian) {
|
||||
Span<float const> hessian) {
|
||||
HistogramCuts out;
|
||||
auto const &info = m->Info();
|
||||
auto n_threads = ctx->Threads();
|
||||
|
||||
@ -19,14 +19,13 @@
|
||||
#include <vector>
|
||||
|
||||
#include "categorical.h"
|
||||
#include "cuda_context.cuh" // for CUDAContext
|
||||
#include "device_helpers.cuh"
|
||||
#include "hist_util.cuh"
|
||||
#include "hist_util.h"
|
||||
#include "math.h" // NOLINT
|
||||
#include "quantile.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
|
||||
namespace xgboost::common {
|
||||
constexpr float SketchContainer::kFactor;
|
||||
|
||||
@ -109,22 +108,19 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_ro
|
||||
return std::min(sketch_batch_num_elements, kIntMax);
|
||||
}
|
||||
|
||||
void SortByWeight(dh::device_vector<float>* weights,
|
||||
dh::device_vector<Entry>* sorted_entries) {
|
||||
void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
|
||||
// Sort both entries and wegihts.
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
|
||||
sorted_entries->end(), weights->begin(),
|
||||
detail::EntryCompareOp());
|
||||
CHECK_EQ(weights->size(), sorted_entries->size());
|
||||
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
|
||||
weights->begin(), detail::EntryCompareOp());
|
||||
|
||||
// Scan weights
|
||||
dh::XGBCachingDeviceAllocator<char> caching;
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
|
||||
sorted_entries->begin(), sorted_entries->end(),
|
||||
weights->begin(), weights->begin(),
|
||||
[=] __device__(const Entry& a, const Entry& b) {
|
||||
return a.index == b.index;
|
||||
});
|
||||
thrust::inclusive_scan_by_key(
|
||||
thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
|
||||
weights->begin(),
|
||||
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
|
||||
}
|
||||
|
||||
void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
|
||||
@ -200,159 +196,170 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_r
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
size_t begin, size_t end, SketchContainer *sketch_container,
|
||||
int num_cuts_per_feature, size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo const& info,
|
||||
std::size_t begin, std::size_t end,
|
||||
SketchContainer* sketch_container, // <- output sketch
|
||||
int num_cuts_per_feature, common::Span<float const> sample_weight) {
|
||||
dh::device_vector<Entry> sorted_entries;
|
||||
if (page.data.DeviceCanRead()) {
|
||||
const auto& device_data = page.data.ConstDevicePointer();
|
||||
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end);
|
||||
// direct copy if data is already on device
|
||||
auto const& d_data = page.data.ConstDevicePointer();
|
||||
sorted_entries = dh::device_vector<Entry>(d_data + begin, d_data + end);
|
||||
} else {
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
const auto& h_data = page.data.ConstHostVector();
|
||||
sorted_entries = dh::device_vector<Entry>(h_data.begin() + begin, h_data.begin() + end);
|
||||
}
|
||||
|
||||
bst_row_t base_rowid = page.base_rowid;
|
||||
|
||||
dh::device_vector<float> entry_weight;
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
if (!sample_weight.empty()) {
|
||||
// Expand sample weight into entry weight.
|
||||
CHECK_EQ(sample_weight.size(), info.num_row_);
|
||||
entry_weight.resize(sorted_entries.size());
|
||||
auto d_temp_weight = dh::ToSpan(entry_weight);
|
||||
page.offset.SetDevice(ctx->Device());
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), entry_weight.size(),
|
||||
[=] __device__(std::size_t idx) {
|
||||
std::size_t element_idx = idx + begin;
|
||||
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
||||
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
|
||||
});
|
||||
detail::SortByWeight(&entry_weight, &sorted_entries);
|
||||
} else {
|
||||
thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(),
|
||||
detail::EntryCompareOp());
|
||||
}
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
|
||||
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
|
||||
sorted_entries.data().get(),
|
||||
[] __device__(Entry const &e) -> data::COOTuple {
|
||||
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
|
||||
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
|
||||
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
|
||||
});
|
||||
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
|
||||
detail::GetColumnSizesScan(ctx->Ordinal(), info.num_col_, num_cuts_per_feature,
|
||||
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
|
||||
&column_sizes_scan);
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
|
||||
if (sketch_container->HasCategorical()) {
|
||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
|
||||
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
|
||||
detail::RemoveDuplicatedCategories(ctx->Ordinal(), info, d_cuts_ptr, &sorted_entries, p_weight,
|
||||
&column_sizes_scan);
|
||||
}
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||
|
||||
// add cuts into sketches
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
|
||||
d_cuts_ptr, h_cuts_ptr.back());
|
||||
// Add cuts into sketches
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
|
||||
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
CHECK_EQ(sorted_entries.capacity(), 0);
|
||||
CHECK_NE(cuts_ptr.Size(), 0);
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
MetaInfo const& info, size_t begin, size_t end,
|
||||
SketchContainer* sketch_container, int num_cuts_per_feature,
|
||||
size_t num_columns,
|
||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
|
||||
// Binary search to assign weights to each element
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = temp_weights.data().get();
|
||||
page.offset.SetDevice(device);
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
size_t base_rowid = page.base_rowid;
|
||||
if (is_ranking) {
|
||||
CHECK_GE(d_group_ptr.size(), 2)
|
||||
<< "Must have at least 1 group for ranking.";
|
||||
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
|
||||
// Unify group weight, Hessian, and sample weight into sample weight.
|
||||
[[nodiscard]] Span<float const> UnifyWeight(CUDAContext const* cuctx, MetaInfo const& info,
|
||||
common::Span<float const> hessian,
|
||||
HostDeviceVector<float>* p_out_weight) {
|
||||
if (hessian.empty()) {
|
||||
if (info.IsRanking() && !info.weights_.Empty()) {
|
||||
common::Span<float const> group_weight = info.weights_.ConstDeviceSpan();
|
||||
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||
auto d_group_ptr = dh::ToSpan(group_ptr);
|
||||
CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
|
||||
auto d_weight = info.weights_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1)
|
||||
<< "Weight size should equal to number of groups.";
|
||||
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) {
|
||||
size_t element_idx = idx + begin;
|
||||
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid);
|
||||
d_temp_weights[idx] = weights[group_idx];
|
||||
p_out_weight->Resize(info.num_row_);
|
||||
auto d_weight_out = p_out_weight->DeviceSpan();
|
||||
|
||||
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), d_weight_out.size(),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) {
|
||||
auto gidx = dh::SegmentId(d_group_ptr, i);
|
||||
d_weight_out[i] = d_weight[gidx];
|
||||
});
|
||||
return p_out_weight->ConstDeviceSpan();
|
||||
} else {
|
||||
return info.weights_.ConstDeviceSpan();
|
||||
}
|
||||
}
|
||||
|
||||
// sketch with hessian as weight
|
||||
p_out_weight->Resize(info.num_row_);
|
||||
auto d_weight_out = p_out_weight->DeviceSpan();
|
||||
if (!info.weights_.Empty()) {
|
||||
// merge sample weight with hessian
|
||||
auto d_weight = info.weights_.ConstDeviceSpan();
|
||||
if (info.IsRanking()) {
|
||||
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||
CHECK_EQ(hessian.size(), d_weight_out.size());
|
||||
auto d_group_ptr = dh::ToSpan(group_ptr);
|
||||
CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
|
||||
CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1)
|
||||
<< "Weight size should equal to number of groups.";
|
||||
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) {
|
||||
d_weight_out[i] = d_weight[dh::SegmentId(d_group_ptr, i)] * hessian(i);
|
||||
});
|
||||
} else {
|
||||
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) {
|
||||
size_t element_idx = idx + begin;
|
||||
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||
});
|
||||
CHECK_EQ(hessian.size(), info.num_row_);
|
||||
CHECK_EQ(hessian.size(), d_weight.size());
|
||||
CHECK_EQ(hessian.size(), d_weight_out.size());
|
||||
thrust::for_each_n(
|
||||
cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) { d_weight_out[i] = d_weight[i] * hessian(i); });
|
||||
}
|
||||
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
|
||||
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
|
||||
sorted_entries.data().get(),
|
||||
[] __device__(Entry const &e) -> data::COOTuple {
|
||||
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
|
||||
});
|
||||
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
|
||||
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
|
||||
&column_sizes_scan);
|
||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||
if (sketch_container->HasCategorical()) {
|
||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
|
||||
&column_sizes_scan);
|
||||
} else {
|
||||
// copy hessian as weight
|
||||
CHECK_EQ(d_weight_out.size(), hessian.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_weight_out.data(), hessian.data(), hessian.size_bytes(),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
|
||||
// Extract cuts
|
||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
return d_weight_out;
|
||||
}
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements) {
|
||||
dmat->Info().feature_types.SetDevice(device);
|
||||
dmat->Info().feature_types.ConstDevicePointer(); // pull to device early
|
||||
HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
|
||||
Span<float const> hessian,
|
||||
std::size_t sketch_batch_num_elements) {
|
||||
auto const& info = p_fmat->Info();
|
||||
bool has_weight = !info.weights_.Empty();
|
||||
info.feature_types.SetDevice(ctx->Device());
|
||||
|
||||
HostDeviceVector<float> weight;
|
||||
weight.SetDevice(ctx->Device());
|
||||
|
||||
// Configure batch size based on available memory
|
||||
bool has_weights = dmat->Info().weights_.Size() > 0;
|
||||
size_t num_cuts_per_feature =
|
||||
detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_);
|
||||
std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
dmat->Info().num_row_,
|
||||
dmat->Info().num_col_,
|
||||
dmat->Info().num_nonzero_,
|
||||
device, num_cuts_per_feature, has_weights);
|
||||
sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(),
|
||||
num_cuts_per_feature, has_weight);
|
||||
|
||||
CUDAContext const* cuctx = ctx->CUDACtx();
|
||||
|
||||
info.weights_.SetDevice(ctx->Device());
|
||||
auto d_weight = UnifyWeight(cuctx, info, hessian, &weight);
|
||||
|
||||
HistogramCuts cuts;
|
||||
SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_,
|
||||
dmat->Info().num_row_, device);
|
||||
SketchContainer sketch_container(info.feature_types, max_bin, info.num_col_, info.num_row_,
|
||||
ctx->Ordinal());
|
||||
CHECK_EQ(has_weight || !hessian.empty(), !d_weight.empty());
|
||||
for (const auto& page : p_fmat->GetBatches<SparsePage>()) {
|
||||
std::size_t page_nnz = page.data.Size();
|
||||
for (auto begin = 0ull; begin < page_nnz; begin += sketch_batch_num_elements) {
|
||||
std::size_t end =
|
||||
std::min(page_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
|
||||
ProcessWeightedBatch(ctx, page, info, begin, end, &sketch_container, num_cuts_per_feature,
|
||||
d_weight);
|
||||
}
|
||||
}
|
||||
|
||||
dmat->Info().weights_.SetDevice(device);
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
size_t batch_nnz = batch.data.Size();
|
||||
auto const& info = dmat->Info();
|
||||
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
|
||||
if (has_weights) {
|
||||
bool is_ranking = HostSketchContainer::UseGroup(dmat->Info());
|
||||
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
||||
info.group_ptr_.cend());
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info(), begin, end,
|
||||
&sketch_container,
|
||||
num_cuts_per_feature,
|
||||
dmat->Info().num_col_,
|
||||
is_ranking, dh::ToSpan(groups));
|
||||
} else {
|
||||
ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container,
|
||||
num_cuts_per_feature, dmat->Info().num_col_);
|
||||
}
|
||||
}
|
||||
}
|
||||
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
|
||||
sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
|
||||
return cuts;
|
||||
}
|
||||
} // namespace xgboost::common
|
||||
|
||||
@ -11,14 +11,13 @@
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/adapter.h" // for IsValidFunctor
|
||||
#include "device_helpers.cuh"
|
||||
#include "hist_util.h"
|
||||
#include "quantile.cuh"
|
||||
#include "timer.h"
|
||||
#include "xgboost/span.h" // for IterSpan
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
namespace cuda {
|
||||
/**
|
||||
* copy and paste of the host version, we can't make it a __host__ __device__ function as
|
||||
@ -246,10 +245,35 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span<bst_r
|
||||
dh::caching_device_vector<size_t>* p_column_sizes_scan);
|
||||
} // namespace detail
|
||||
|
||||
// Compute sketch on DMatrix.
|
||||
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements = 0);
|
||||
/**
|
||||
* @brief Compute sketch on DMatrix with GPU and Hessian as weight.
|
||||
*
|
||||
* @param ctx Runtime context
|
||||
* @param p_fmat Training feature matrix
|
||||
* @param max_bin Maximum number of bins for each feature
|
||||
* @param hessian Hessian vector.
|
||||
* @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
*
|
||||
* @return Quantile cuts
|
||||
*/
|
||||
HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
|
||||
Span<float const> hessian,
|
||||
std::size_t sketch_batch_num_elements = 0);
|
||||
|
||||
/**
|
||||
* @brief Compute sketch on DMatrix with GPU.
|
||||
*
|
||||
* @param ctx Runtime context
|
||||
* @param p_fmat Training feature matrix
|
||||
* @param max_bin Maximum number of bins for each feature
|
||||
* @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
*
|
||||
* @return Quantile cuts
|
||||
*/
|
||||
inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
|
||||
std::size_t sketch_batch_num_elements = 0) {
|
||||
return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements);
|
||||
}
|
||||
|
||||
template <typename AdapterBatch>
|
||||
void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
|
||||
@ -417,7 +441,5 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
} // namespace xgboost::common
|
||||
#endif // COMMON_HIST_UTIL_CUH_
|
||||
|
||||
@ -172,7 +172,7 @@ class HistogramCuts {
|
||||
* but consumes more memory.
|
||||
*/
|
||||
HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins,
|
||||
bool use_sorted = false, Span<float> const hessian = {});
|
||||
bool use_sorted = false, Span<float const> hessian = {});
|
||||
|
||||
enum BinTypeSize : uint8_t {
|
||||
kUint8BinsTypeSize = 1,
|
||||
|
||||
@ -168,6 +168,9 @@ bool HostDeviceVector<T>::DeviceCanWrite() const {
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::SetDevice(int) const {}
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::SetDevice(DeviceOrd) const {}
|
||||
|
||||
// explicit instantiations are required, as HostDeviceVector isn't header-only
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
|
||||
@ -394,6 +394,11 @@ void HostDeviceVector<T>::SetDevice(int device) const {
|
||||
impl_->SetDevice(device);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::SetDevice(DeviceOrd device) const {
|
||||
impl_->SetDevice(device.ordinal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::Resize(size_t new_size, T v) {
|
||||
impl_->Resize(new_size, v);
|
||||
|
||||
@ -131,7 +131,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
|
||||
monitor_.Start("Quantiles");
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||
row_stride = GetRowStride(dmat);
|
||||
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin);
|
||||
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
|
||||
monitor_.Stop("Quantiles");
|
||||
|
||||
monitor_.Start("InitCompressedData");
|
||||
|
||||
@ -21,7 +21,7 @@ GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnM
|
||||
|
||||
GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
double sparse_thresh, bool sorted_sketch,
|
||||
common::Span<float> hess)
|
||||
common::Span<float const> hess)
|
||||
: max_numeric_bins_per_feat{max_bins_per_feat} {
|
||||
CHECK(p_fmat->SingleColBlock());
|
||||
// We use sorted sketching for approx tree method since it's more efficient in
|
||||
|
||||
@ -160,7 +160,7 @@ class GHistIndexMatrix {
|
||||
* \brief Constrcutor for SimpleDMatrix.
|
||||
*/
|
||||
GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
|
||||
double sparse_thresh, bool sorted_sketch, common::Span<float> hess = {});
|
||||
double sparse_thresh, bool sorted_sketch, common::Span<float const> hess = {});
|
||||
/**
|
||||
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
|
||||
* for push batch.
|
||||
|
||||
@ -25,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
std::unique_ptr<common::HistogramCuts> cuts;
|
||||
cuts = std::make_unique<common::HistogramCuts>(
|
||||
common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0));
|
||||
cuts =
|
||||
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
|
||||
this->InitializeSparsePage(ctx); // reset after use.
|
||||
|
||||
row_stride = GetRowStride(this);
|
||||
|
||||
@ -3,17 +3,22 @@
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <xgboost/base.h> // for bst_bin_t
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <algorithm> // for transform
|
||||
#include <cmath> // for floor
|
||||
#include <cstddef> // for size_t
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string, to_string
|
||||
#include <tuple> // for tuple, make_tuple
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../include/xgboost/logging.h"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "../../../src/common/hist_util.cuh"
|
||||
#include "../../../src/common/hist_util.h"
|
||||
#include "../../../src/common/math.h"
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "../data/test_array_interface.h"
|
||||
@ -21,8 +26,7 @@
|
||||
#include "../helpers.h"
|
||||
#include "test_hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
|
||||
template <typename AdapterT>
|
||||
HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) {
|
||||
@ -32,16 +36,17 @@ HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, f
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketch) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_columns = 1;
|
||||
int num_bins = 4;
|
||||
std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f};
|
||||
int num_rows = x.size();
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
|
||||
Context ctx;
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins);
|
||||
Context cpu_ctx;
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(&cpu_ctx, dmat.get(), num_bins);
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||
@ -65,6 +70,7 @@ TEST(HistUtil, SketchBatchNumElements) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMemory) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
@ -73,7 +79,7 @@ TEST(HistUtil, DeviceSketchMemory) {
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, false);
|
||||
@ -83,6 +89,7 @@ TEST(HistUtil, DeviceSketchMemory) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchWeightsMemory) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
@ -92,7 +99,7 @@ TEST(HistUtil, DeviceSketchWeightsMemory) {
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
@ -102,42 +109,43 @@ TEST(HistUtil, DeviceSketchWeightsMemory) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchDeterminism) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_rows = 500;
|
||||
int num_columns = 5;
|
||||
int num_bins = 256;
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto reference_sketch = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto reference_sketch = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
size_t constexpr kRounds{ 100 };
|
||||
for (size_t r = 0; r < kRounds; ++r) {
|
||||
auto new_sketch = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto new_sketch = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ASSERT_EQ(reference_sketch.Values(), new_sketch.Values());
|
||||
ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchCategoricalAsNumeric) {
|
||||
int categorical_sizes[] = {2, 6, 8, 12};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto categorical_sizes = {2, 6, 8, 12};
|
||||
int num_bins = 256;
|
||||
int sizes[] = {25, 100, 1000};
|
||||
auto sizes = {25, 100, 1000};
|
||||
for (auto n : sizes) {
|
||||
for (auto num_categories : categorical_sizes) {
|
||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32, false,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return DeviceSketch(0, p_fmat, num_bins);
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
TestCategoricalSketch(1000, 256, 32, false, [ctx](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return DeviceSketch(&ctx, p_fmat, num_bins);
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return DeviceSketch(0, p_fmat, num_bins);
|
||||
TestCategoricalSketch(1000, 256, 32, true, [ctx](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return DeviceSketch(&ctx, p_fmat, num_bins);
|
||||
});
|
||||
}
|
||||
|
||||
@ -162,7 +170,8 @@ void TestMixedSketch() {
|
||||
m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||
m->Info().feature_types.HostVector().push_back(FeatureType::kNumerical);
|
||||
|
||||
auto cuts = DeviceSketch(0, m.get(), n_bins);
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto cuts = DeviceSketch(&ctx, m.get(), n_bins);
|
||||
ASSERT_EQ(cuts.Values().size(), n_bins + n_categories);
|
||||
}
|
||||
|
||||
@ -234,37 +243,40 @@ TEST(HistUtil, RemoveDuplicatedCategories) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMultipleColumns) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMultipleColumnsWeights) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUitl, DeviceSketchWeights) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
@ -274,8 +286,8 @@ TEST(HistUitl, DeviceSketchWeights) {
|
||||
h_weights.resize(num_rows);
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
auto wcuts = DeviceSketch(&ctx, weighted_dmat.get(), num_bins);
|
||||
ASSERT_EQ(cuts.MinValues(), wcuts.MinValues());
|
||||
ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs());
|
||||
ASSERT_EQ(cuts.Values(), wcuts.Values());
|
||||
@ -286,14 +298,15 @@ TEST(HistUitl, DeviceSketchWeights) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchBatches) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_bins = 256;
|
||||
int num_rows = 5000;
|
||||
int batch_sizes[] = {0, 100, 1500, 6000};
|
||||
auto batch_sizes = {0, 100, 1500, 6000};
|
||||
int num_columns = 5;
|
||||
for (auto batch_size : batch_sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins, batch_size);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
|
||||
@ -301,8 +314,8 @@ TEST(HistUtil, DeviceSketchBatches) {
|
||||
size_t batches = 16;
|
||||
auto x = GenerateRandom(num_rows * batches, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns);
|
||||
auto cuts_with_batches = DeviceSketch(0, dmat.get(), num_bins, num_rows);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
auto cuts_with_batches = DeviceSketch(&ctx, dmat.get(), num_bins, num_rows);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins, 0);
|
||||
|
||||
auto const& cut_values_batched = cuts_with_batches.Values();
|
||||
auto const& cut_values = cuts.Values();
|
||||
@ -313,15 +326,16 @@ TEST(HistUtil, DeviceSketchBatches) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns =5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
dmlc::TemporaryDirectory temp;
|
||||
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -329,8 +343,9 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
|
||||
|
||||
// See https://github.com/dmlc/xgboost/issues/5866.
|
||||
TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
dmlc::TemporaryDirectory temp;
|
||||
for (auto num_rows : sizes) {
|
||||
@ -338,7 +353,7 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
|
||||
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp);
|
||||
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -504,9 +519,9 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
int categorical_sizes[] = {2, 6, 8, 12};
|
||||
auto categorical_sizes = {2, 6, 8, 12};
|
||||
int num_bins = 256;
|
||||
int sizes[] = {25, 100, 1000};
|
||||
auto sizes = {25, 100, 1000};
|
||||
for (auto n : sizes) {
|
||||
for (auto num_categories : categorical_sizes) {
|
||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
@ -521,8 +536,8 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
@ -538,7 +553,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
|
||||
TEST(HistUtil, AdapterDeviceSketchBatches) {
|
||||
int num_bins = 256;
|
||||
int num_rows = 5000;
|
||||
int batch_sizes[] = {0, 100, 1500, 6000};
|
||||
auto batch_sizes = {0, 100, 1500, 6000};
|
||||
int num_columns = 5;
|
||||
for (auto batch_size : batch_sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
@ -619,14 +634,15 @@ TEST(HistUtil, GetColumnSize) {
|
||||
// Check sketching from adapter or DMatrix results in the same answer
|
||||
// Consistency here is useful for testing and user experience
|
||||
TEST(HistUtil, SketchingEquivalent) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto dmat_cuts = DeviceSketch(&ctx, dmat.get(), num_bins);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest(
|
||||
@ -641,21 +657,25 @@ TEST(HistUtil, SketchingEquivalent) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchFromGroupWeights) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
|
||||
size_t constexpr kGroups = 10;
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
|
||||
// sketch with group weight
|
||||
auto& h_weights = m->Info().weights_.HostVector();
|
||||
h_weights.resize(kRows);
|
||||
h_weights.resize(kGroups);
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
std::vector<bst_group_t> groups(kGroups);
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);
|
||||
HistogramCuts weighted_cuts = DeviceSketch(&ctx, m.get(), kBins, 0);
|
||||
|
||||
// sketch with no weight
|
||||
h_weights.clear();
|
||||
HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0);
|
||||
HistogramCuts cuts = DeviceSketch(&ctx, m.get(), kBins, 0);
|
||||
|
||||
ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
|
||||
ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
|
||||
@ -723,9 +743,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
ValidateCuts(cuts, dmat.get(), kBins);
|
||||
|
||||
auto cuda_ctx = MakeCUDACtx(0);
|
||||
if (with_group) {
|
||||
dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight
|
||||
HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0);
|
||||
HistogramCuts non_weighted = DeviceSketch(&cuda_ctx, dmat.get(), kBins, 0);
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
}
|
||||
@ -760,5 +781,156 @@ TEST(HistUtil, AdapterSketchFromWeights) {
|
||||
TestAdapterSketchFromWeights(false);
|
||||
TestAdapterSketchFromWeights(true);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
namespace {
|
||||
class DeviceSketchWithHessianTest
|
||||
: public ::testing::TestWithParam<std::tuple<bool, bst_row_t, bst_bin_t>> {
|
||||
bst_feature_t n_features_ = 5;
|
||||
bst_group_t n_groups_{3};
|
||||
|
||||
auto GenerateHessian(Context const* ctx, bst_row_t n_samples) const {
|
||||
HostDeviceVector<float> hessian;
|
||||
auto& h_hess = hessian.HostVector();
|
||||
h_hess = GenerateRandomWeights(n_samples);
|
||||
std::mt19937 rng(0);
|
||||
std::shuffle(h_hess.begin(), h_hess.end(), rng);
|
||||
hessian.SetDevice(ctx->Device());
|
||||
return hessian;
|
||||
}
|
||||
|
||||
void CheckReg(Context const* ctx, std::shared_ptr<DMatrix> p_fmat, bst_bin_t n_bins,
|
||||
HostDeviceVector<float> const& hessian, std::vector<float> const& w,
|
||||
std::size_t n_elements) const {
|
||||
auto const& h_hess = hessian.ConstHostVector();
|
||||
{
|
||||
auto& h_weight = p_fmat->Info().weights_.HostVector();
|
||||
h_weight = w;
|
||||
}
|
||||
|
||||
HistogramCuts cuts_hess =
|
||||
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
|
||||
ValidateCuts(cuts_hess, p_fmat.get(), n_bins);
|
||||
|
||||
// merge hessian
|
||||
{
|
||||
auto& h_weight = p_fmat->Info().weights_.HostVector();
|
||||
ASSERT_EQ(h_weight.size(), h_hess.size());
|
||||
for (std::size_t i = 0; i < h_weight.size(); ++i) {
|
||||
h_weight[i] = w[i] * h_hess[i];
|
||||
}
|
||||
}
|
||||
|
||||
HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements);
|
||||
ValidateCuts(cuts_wh, p_fmat.get(), n_bins);
|
||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||
for (std::size_t i = 0; i < cuts_hess.Values().size(); ++i) {
|
||||
ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps);
|
||||
}
|
||||
|
||||
p_fmat->Info().weights_.HostVector() = w;
|
||||
}
|
||||
|
||||
protected:
|
||||
Context ctx_ = MakeCUDACtx(0);
|
||||
|
||||
void TestLTR(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins,
|
||||
std::size_t n_elements) const {
|
||||
auto x = GenerateRandom(n_samples, n_features_);
|
||||
|
||||
std::vector<bst_group_t> gptr;
|
||||
gptr.resize(n_groups_ + 1, 0);
|
||||
gptr[1] = n_samples / n_groups_;
|
||||
gptr[2] = n_samples / n_groups_ + gptr[1];
|
||||
gptr.back() = n_samples;
|
||||
|
||||
auto hessian = this->GenerateHessian(ctx, n_samples);
|
||||
auto const& h_hess = hessian.ConstHostVector();
|
||||
auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_);
|
||||
p_fmat->Info().group_ptr_ = gptr;
|
||||
|
||||
// test with constant group weight
|
||||
std::vector<float> w(n_groups_, 1.0f);
|
||||
p_fmat->Info().weights_.HostVector() = w;
|
||||
HistogramCuts cuts_hess =
|
||||
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
|
||||
// make validation easier by converting it into sample weight.
|
||||
p_fmat->Info().weights_.HostVector() = h_hess;
|
||||
p_fmat->Info().group_ptr_.clear();
|
||||
ValidateCuts(cuts_hess, p_fmat.get(), n_bins);
|
||||
// restore ltr properties
|
||||
p_fmat->Info().weights_.HostVector() = w;
|
||||
p_fmat->Info().group_ptr_ = gptr;
|
||||
|
||||
// test with random group weight
|
||||
w = GenerateRandomWeights(n_groups_);
|
||||
p_fmat->Info().weights_.HostVector() = w;
|
||||
cuts_hess =
|
||||
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
|
||||
// make validation easier by converting it into sample weight.
|
||||
p_fmat->Info().weights_.HostVector() = h_hess;
|
||||
p_fmat->Info().group_ptr_.clear();
|
||||
ValidateCuts(cuts_hess, p_fmat.get(), n_bins);
|
||||
|
||||
// merge hessian with sample weight
|
||||
p_fmat->Info().weights_.Resize(n_samples);
|
||||
p_fmat->Info().group_ptr_.clear();
|
||||
for (std::size_t i = 0; i < h_hess.size(); ++i) {
|
||||
auto gidx = dh::SegmentId(Span{gptr.data(), gptr.size()}, i);
|
||||
p_fmat->Info().weights_.HostVector()[i] = w[gidx] * h_hess[i];
|
||||
}
|
||||
auto cuts = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements);
|
||||
ValidateCuts(cuts, p_fmat.get(), n_bins);
|
||||
ASSERT_EQ(cuts.Values().size(), cuts_hess.Values().size());
|
||||
for (std::size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_NEAR(cuts.Values()[i], cuts_hess.Values()[i], 1e-4f);
|
||||
}
|
||||
}
|
||||
|
||||
void TestRegression(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins,
|
||||
std::size_t n_elements) const {
|
||||
auto x = GenerateRandom(n_samples, n_features_);
|
||||
auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_);
|
||||
std::vector<float> w = GenerateRandomWeights(n_samples);
|
||||
|
||||
auto hessian = this->GenerateHessian(ctx, n_samples);
|
||||
|
||||
this->CheckReg(ctx, p_fmat, n_bins, hessian, w, n_elements);
|
||||
}
|
||||
};
|
||||
|
||||
auto MakeParamsForTest() {
|
||||
std::vector<bst_row_t> sizes = {1, 2, 256, 512, 1000, 1500};
|
||||
std::vector<bst_bin_t> bin_sizes = {2, 16, 256, 512};
|
||||
std::vector<std::tuple<bool, bst_row_t, bst_bin_t>> configs;
|
||||
for (auto n_samples : sizes) {
|
||||
for (auto n_bins : bin_sizes) {
|
||||
configs.emplace_back(true, n_samples, n_bins);
|
||||
configs.emplace_back(false, n_samples, n_bins);
|
||||
}
|
||||
}
|
||||
return configs;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_P(DeviceSketchWithHessianTest, DeviceSketchWithHessian) {
|
||||
auto param = GetParam();
|
||||
auto n_samples = std::get<1>(param);
|
||||
auto n_bins = std::get<2>(param);
|
||||
if (std::get<0>(param)) {
|
||||
this->TestLTR(&ctx_, n_samples, n_bins, 0);
|
||||
this->TestLTR(&ctx_, n_samples, n_bins, 512);
|
||||
} else {
|
||||
this->TestRegression(&ctx_, n_samples, n_bins, 0);
|
||||
this->TestRegression(&ctx_, n_samples, n_bins, 512);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
HistUtil, DeviceSketchWithHessianTest, ::testing::ValuesIn(MakeParamsForTest()),
|
||||
[](::testing::TestParamInfo<DeviceSketchWithHessianTest::ParamType> const& info) {
|
||||
auto task = std::get<0>(info.param) ? "ltr" : "reg";
|
||||
auto n_samples = std::to_string(std::get<1>(info.param));
|
||||
auto n_bins = std::to_string(std::get<2>(info.param));
|
||||
return std::string{task} + "_" + n_samples + "_" + n_bins;
|
||||
});
|
||||
} // namespace xgboost::common
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_quantile.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/common/hist_util.cuh"
|
||||
#include "../../../src/common/quantile.cuh"
|
||||
#include "../../../src/data/device_adapter.cuh" // CupyAdapter
|
||||
#include "../helpers.h"
|
||||
#include "test_quantile.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
@ -437,13 +442,13 @@ void TestColumnSplitBasic() {
|
||||
}()};
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto const device = rank;
|
||||
HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins);
|
||||
auto ctx = MakeCUDACtx(rank);
|
||||
HistogramCuts distributed_cuts = common::DeviceSketch(&ctx, m.get(), kBins);
|
||||
|
||||
// Generate cuts for single node environment
|
||||
collective::Finalize();
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins);
|
||||
HistogramCuts single_node_cuts = common::DeviceSketch(&ctx, m.get(), kBins);
|
||||
|
||||
auto const& sptrs = single_node_cuts.Ptrs();
|
||||
auto const& dptrs = distributed_cuts.Ptrs();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user