parent
61dd854a52
commit
34d4ab455e
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2022-2023 by XGBoost Contributors
|
* Copyright 2022-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
|
#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
|
||||||
#define XGBOOST_COMMON_ALGORITHM_CUH_
|
#define XGBOOST_COMMON_ALGORITHM_CUH_
|
||||||
@ -258,5 +258,19 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
|
|||||||
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
|
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
|
||||||
cuctx->Stream()));
|
cuctx->Stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename InIt, typename OutIt, typename Predicate>
|
||||||
|
void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_first,
|
||||||
|
Predicate pred) {
|
||||||
|
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
|
||||||
|
// See thrust issue #1302, XGBoost #6822
|
||||||
|
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
|
||||||
|
size_t length = std::distance(in_first, in_second);
|
||||||
|
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
|
||||||
|
auto begin_input = in_first + offset;
|
||||||
|
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
|
||||||
|
out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost::common
|
} // namespace xgboost::common
|
||||||
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
|
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
|
||||||
|
|||||||
@ -637,9 +637,8 @@ struct SegmentedUniqueReduceOp {
|
|||||||
* \return Number of unique values in total.
|
* \return Number of unique values in total.
|
||||||
*/
|
*/
|
||||||
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
|
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
|
||||||
typename ValOutIt, typename CompValue, typename CompKey>
|
typename ValOutIt, typename CompValue, typename CompKey = thrust::equal_to<size_t>>
|
||||||
size_t
|
size_t SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
|
||||||
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
|
|
||||||
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
|
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
|
||||||
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
|
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
|
||||||
CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
|
CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
|
||||||
@ -676,16 +675,6 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
|
|||||||
return n_uniques;
|
return n_uniques;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Inputs,
|
|
||||||
std::enable_if_t<std::tuple_size<std::tuple<Inputs...>>::value == 7>
|
|
||||||
* = nullptr>
|
|
||||||
size_t SegmentedUnique(Inputs &&...inputs) {
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
return SegmentedUnique(thrust::cuda::par(alloc),
|
|
||||||
std::forward<Inputs &&>(inputs)...,
|
|
||||||
thrust::equal_to<size_t>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Unique by key for many groups of data. Has same constraint as `SegmentedUnique`.
|
* \brief Unique by key for many groups of data. Has same constraint as `SegmentedUnique`.
|
||||||
*
|
*
|
||||||
@ -793,21 +782,6 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InIt, typename OutIt, typename Predicate>
|
|
||||||
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
|
|
||||||
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
|
|
||||||
// See thrust issue #1302, XGBoost #6822
|
|
||||||
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
|
|
||||||
size_t length = std::distance(in_first, in_second);
|
|
||||||
XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
|
|
||||||
auto begin_input = in_first + offset;
|
|
||||||
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
|
|
||||||
out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
|
|
||||||
end_input, out_first, pred);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
|
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
|
||||||
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
|
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
|
||||||
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
|
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
|
||||||
|
|||||||
@ -106,26 +106,27 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_ro
|
|||||||
return std::min(sketch_batch_num_elements, kIntMax);
|
return std::min(sketch_batch_num_elements, kIntMax);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
|
void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
|
||||||
|
dh::device_vector<Entry>* sorted_entries) {
|
||||||
// Sort both entries and weights.
|
// Sort both entries and weights.
|
||||||
dh::XGBDeviceAllocator<char> alloc;
|
auto cuctx = ctx->CUDACtx();
|
||||||
CHECK_EQ(weights->size(), sorted_entries->size());
|
CHECK_EQ(weights->size(), sorted_entries->size());
|
||||||
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
|
thrust::sort_by_key(cuctx->TP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
|
||||||
weights->begin(), detail::EntryCompareOp());
|
detail::EntryCompareOp());
|
||||||
|
|
||||||
// Scan weights
|
// Scan weights
|
||||||
dh::XGBCachingDeviceAllocator<char> caching;
|
|
||||||
thrust::inclusive_scan_by_key(
|
thrust::inclusive_scan_by_key(
|
||||||
thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
|
cuctx->CTP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
|
||||||
weights->begin(),
|
weights->begin(),
|
||||||
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
|
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
|
void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
|
||||||
|
Span<bst_idx_t> d_cuts_ptr,
|
||||||
dh::device_vector<Entry>* p_sorted_entries,
|
dh::device_vector<Entry>* p_sorted_entries,
|
||||||
dh::device_vector<float>* p_sorted_weights,
|
dh::device_vector<float>* p_sorted_weights,
|
||||||
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
||||||
info.feature_types.SetDevice(device);
|
info.feature_types.SetDevice(ctx->Device());
|
||||||
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
auto d_feature_types = info.feature_types.ConstDeviceSpan();
|
||||||
CHECK(!d_feature_types.empty());
|
CHECK(!d_feature_types.empty());
|
||||||
auto& column_sizes_scan = *p_column_sizes_scan;
|
auto& column_sizes_scan = *p_column_sizes_scan;
|
||||||
@ -142,10 +143,11 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
|
|||||||
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
|
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
|
||||||
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
|
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
|
||||||
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
|
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
|
||||||
n_uniques = dh::SegmentedUnique(
|
n_uniques =
|
||||||
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
|
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
|
||||||
val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
|
column_sizes_scan.data().get() + column_sizes_scan.size(), val_in_it,
|
||||||
[=] __device__(Pair const& l, Pair const& r) {
|
val_in_it + sorted_entries.size(), new_column_scan.data().get(),
|
||||||
|
val_out_it, [=] __device__(Pair const& l, Pair const& r) {
|
||||||
Entry const& le = thrust::get<0>(l);
|
Entry const& le = thrust::get<0>(l);
|
||||||
Entry const& re = thrust::get<0>(r);
|
Entry const& re = thrust::get<0>(r);
|
||||||
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
|
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
|
||||||
@ -155,10 +157,11 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
|
|||||||
});
|
});
|
||||||
p_sorted_weights->resize(n_uniques);
|
p_sorted_weights->resize(n_uniques);
|
||||||
} else {
|
} else {
|
||||||
n_uniques = dh::SegmentedUnique(
|
n_uniques = dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
|
||||||
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
|
column_sizes_scan.data().get() + column_sizes_scan.size(),
|
||||||
sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
|
sorted_entries.begin(), sorted_entries.end(),
|
||||||
sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
|
new_column_scan.data().get(), sorted_entries.begin(),
|
||||||
|
[=] __device__(Entry const& l, Entry const& r) {
|
||||||
if (l.index == r.index) {
|
if (l.index == r.index) {
|
||||||
if (IsCat(d_feature_types, l.index)) {
|
if (IsCat(d_feature_types, l.index)) {
|
||||||
return l.fvalue == r.fvalue;
|
return l.fvalue == r.fvalue;
|
||||||
@ -189,7 +192,7 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
// Turn size into ptr.
|
// Turn size into ptr.
|
||||||
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), new_cuts_size.cend(),
|
thrust::exclusive_scan(ctx->CUDACtx()->CTP(), new_cuts_size.cbegin(), new_cuts_size.cend(),
|
||||||
d_cuts_ptr.data());
|
d_cuts_ptr.data());
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
@ -225,7 +228,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
|
|||||||
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
|
||||||
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
|
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
|
||||||
});
|
});
|
||||||
detail::SortByWeight(&entry_weight, &sorted_entries);
|
detail::SortByWeight(ctx, &entry_weight, &sorted_entries);
|
||||||
} else {
|
} else {
|
||||||
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
|
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
|
||||||
detail::EntryCompareOp());
|
detail::EntryCompareOp());
|
||||||
@ -238,13 +241,13 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
|
|||||||
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
|
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
|
||||||
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
|
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
|
||||||
});
|
});
|
||||||
detail::GetColumnSizesScan(ctx->Device(), info.num_col_, num_cuts_per_feature,
|
detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_, num_cuts_per_feature,
|
||||||
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
|
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
|
||||||
&column_sizes_scan);
|
&column_sizes_scan);
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
|
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
|
||||||
detail::RemoveDuplicatedCategories(ctx->Device(), info, d_cuts_ptr, &sorted_entries, p_weight,
|
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
|
||||||
&column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -252,7 +255,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
|
|||||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||||
|
|
||||||
// Add cuts into sketches
|
// Add cuts into sketches
|
||||||
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||||
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
|
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
|
||||||
|
|
||||||
sorted_entries.clear();
|
sorted_entries.clear();
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2020-2023 by XGBoost contributors
|
* Copyright 2020-2024, XGBoost contributors
|
||||||
*
|
*
|
||||||
* \brief Front end and utilities for GPU based sketching. Works on sliding window
|
* \brief Front end and utilities for GPU based sketching. Works on sliding window
|
||||||
* instead of stream.
|
* instead of stream.
|
||||||
@ -13,6 +13,8 @@
|
|||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
|
|
||||||
#include "../data/adapter.h" // for IsValidFunctor
|
#include "../data/adapter.h" // for IsValidFunctor
|
||||||
|
#include "algorithm.cuh" // for CopyIf
|
||||||
|
#include "cuda_context.cuh" // for CUDAContext
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "hist_util.h"
|
#include "hist_util.h"
|
||||||
#include "quantile.cuh"
|
#include "quantile.cuh"
|
||||||
@ -107,9 +109,10 @@ std::uint32_t EstimateGridSize(DeviceOrd device, Kernel kernel, std::size_t shar
|
|||||||
* \param out_column_size Output buffer for the size of each column.
|
* \param out_column_size Output buffer for the size of each column.
|
||||||
*/
|
*/
|
||||||
template <typename BatchIt, bool force_use_global_memory = false, bool force_use_u64 = false>
|
template <typename BatchIt, bool force_use_global_memory = false, bool force_use_u64 = false>
|
||||||
void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan<BatchIt> batch_iter,
|
void LaunchGetColumnSizeKernel(CUDAContext const* cuctx, DeviceOrd device,
|
||||||
data::IsValidFunctor is_valid, Span<std::size_t> out_column_size) {
|
IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
|
||||||
thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0);
|
Span<std::size_t> out_column_size) {
|
||||||
|
thrust::fill_n(cuctx->CTP(), dh::tbegin(out_column_size), out_column_size.size(), 0);
|
||||||
|
|
||||||
std::size_t max_shared_memory = dh::MaxSharedMemory(device.ordinal);
|
std::size_t max_shared_memory = dh::MaxSharedMemory(device.ordinal);
|
||||||
// Not strictly correct as we should use number of samples to determine the type of
|
// Not strictly correct as we should use number of samples to determine the type of
|
||||||
@ -135,17 +138,17 @@ void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan<BatchIt> batch_iter,
|
|||||||
CHECK(!force_use_u64);
|
CHECK(!force_use_u64);
|
||||||
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
|
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
|
||||||
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
|
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
|
||||||
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
|
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, cuctx->Stream()}(
|
||||||
kernel, batch_iter, is_valid, out_column_size);
|
kernel, batch_iter, is_valid, out_column_size);
|
||||||
} else {
|
} else {
|
||||||
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
|
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
|
||||||
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
|
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
|
||||||
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
|
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, cuctx->Stream()}(
|
||||||
kernel, batch_iter, is_valid, out_column_size);
|
kernel, batch_iter, is_valid, out_column_size);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto d_out_column_size = out_column_size;
|
auto d_out_column_size = out_column_size;
|
||||||
dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) {
|
dh::LaunchN(batch_iter.size(), cuctx->Stream(), [=] __device__(size_t idx) {
|
||||||
auto e = batch_iter[idx];
|
auto e = batch_iter[idx];
|
||||||
if (is_valid(e)) {
|
if (is_valid(e)) {
|
||||||
atomicAdd(&d_out_column_size[e.column_idx], static_cast<size_t>(1));
|
atomicAdd(&d_out_column_size[e.column_idx], static_cast<size_t>(1));
|
||||||
@ -155,26 +158,26 @@ void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan<BatchIt> batch_iter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename BatchIt>
|
template <typename BatchIt>
|
||||||
void GetColumnSizesScan(DeviceOrd device, size_t num_columns, std::size_t num_cuts_per_feature,
|
void GetColumnSizesScan(CUDAContext const* cuctx, DeviceOrd device, size_t num_columns,
|
||||||
IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
|
std::size_t num_cuts_per_feature, IterSpan<BatchIt> batch_iter,
|
||||||
|
data::IsValidFunctor is_valid,
|
||||||
HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
|
HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
|
||||||
dh::caching_device_vector<size_t>* column_sizes_scan) {
|
dh::caching_device_vector<size_t>* column_sizes_scan) {
|
||||||
column_sizes_scan->resize(num_columns + 1);
|
column_sizes_scan->resize(num_columns + 1);
|
||||||
cuts_ptr->SetDevice(device);
|
cuts_ptr->SetDevice(device);
|
||||||
cuts_ptr->Resize(num_columns + 1, 0);
|
cuts_ptr->Resize(num_columns + 1, 0);
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
|
auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
|
||||||
LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan);
|
LaunchGetColumnSizeKernel(cuctx, device, batch_iter, is_valid, d_column_sizes_scan);
|
||||||
// Calculate cuts CSC pointer
|
// Calculate cuts CSC pointer
|
||||||
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
|
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
|
||||||
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
|
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
|
||||||
return thrust::min(num_cuts_per_feature, column_size);
|
return thrust::min(num_cuts_per_feature, column_size);
|
||||||
});
|
});
|
||||||
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
|
thrust::exclusive_scan(cuctx->CTP(), cut_ptr_it,
|
||||||
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
|
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
|
||||||
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
|
thrust::exclusive_scan(cuctx->CTP(), column_sizes_scan->begin(), column_sizes_scan->end(),
|
||||||
column_sizes_scan->end(), column_sizes_scan->begin());
|
column_sizes_scan->begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline size_t constexpr BytesPerElement(bool has_weight) {
|
inline size_t constexpr BytesPerElement(bool has_weight) {
|
||||||
@ -215,9 +218,9 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
|
|||||||
|
|
||||||
// Count the valid entries in each column and copy them out.
|
// Count the valid entries in each column and copy them out.
|
||||||
template <typename AdapterBatch, typename BatchIter>
|
template <typename AdapterBatch, typename BatchIter>
|
||||||
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range,
|
void MakeEntriesFromAdapter(CUDAContext const* cuctx, AdapterBatch const& batch,
|
||||||
float missing, size_t columns, size_t cuts_per_feature,
|
BatchIter batch_iter, Range1d range, float missing, size_t columns,
|
||||||
DeviceOrd device,
|
size_t cuts_per_feature, DeviceOrd device,
|
||||||
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
|
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
|
||||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||||
dh::device_vector<Entry>* sorted_entries) {
|
dh::device_vector<Entry>* sorted_entries) {
|
||||||
@ -229,19 +232,20 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran
|
|||||||
auto span = IterSpan{batch_iter + range.begin(), n};
|
auto span = IterSpan{batch_iter + range.begin(), n};
|
||||||
data::IsValidFunctor is_valid(missing);
|
data::IsValidFunctor is_valid(missing);
|
||||||
// Work out how many valid entries we have in each column
|
// Work out how many valid entries we have in each column
|
||||||
GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
|
GetColumnSizesScan(cuctx, device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
|
||||||
column_sizes_scan);
|
column_sizes_scan);
|
||||||
size_t num_valid = column_sizes_scan->back();
|
size_t num_valid = column_sizes_scan->back();
|
||||||
// Copy current subset of valid elements into temporary storage and sort
|
// Copy current subset of valid elements into temporary storage and sort
|
||||||
sorted_entries->resize(num_valid);
|
sorted_entries->resize(num_valid);
|
||||||
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
|
CopyIf(cuctx, entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
|
||||||
is_valid);
|
is_valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SortByWeight(dh::device_vector<float>* weights,
|
void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
|
||||||
dh::device_vector<Entry>* sorted_entries);
|
dh::device_vector<Entry>* sorted_entries);
|
||||||
|
|
||||||
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
|
void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
|
||||||
|
Span<bst_idx_t> d_cuts_ptr,
|
||||||
dh::device_vector<Entry>* p_sorted_entries,
|
dh::device_vector<Entry>* p_sorted_entries,
|
||||||
dh::device_vector<float>* p_sorted_weights,
|
dh::device_vector<float>* p_sorted_weights,
|
||||||
dh::caching_device_vector<size_t>* p_column_sizes_scan);
|
dh::caching_device_vector<size_t>* p_column_sizes_scan);
|
||||||
@ -278,10 +282,9 @@ inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename AdapterBatch>
|
template <typename AdapterBatch>
|
||||||
void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
|
void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info,
|
||||||
DeviceOrd device, size_t columns, size_t begin, size_t end,
|
size_t columns, size_t begin, size_t end, float missing,
|
||||||
float missing, SketchContainer *sketch_container,
|
SketchContainer* sketch_container, int num_cuts) {
|
||||||
int num_cuts) {
|
|
||||||
// Copy current subset of valid elements into temporary storage and sort
|
// Copy current subset of valid elements into temporary storage and sort
|
||||||
dh::device_vector<Entry> sorted_entries;
|
dh::device_vector<Entry> sorted_entries;
|
||||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||||
@ -289,39 +292,32 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
|
|||||||
thrust::make_counting_iterator(0llu),
|
thrust::make_counting_iterator(0llu),
|
||||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||||
cuts_ptr.SetDevice(device);
|
cuts_ptr.SetDevice(ctx->Device());
|
||||||
detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing,
|
CUDAContext const* cuctx = ctx->CUDACtx();
|
||||||
columns, num_cuts, device,
|
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, num_cuts,
|
||||||
&cuts_ptr,
|
ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries);
|
||||||
&column_sizes_scan,
|
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp());
|
||||||
&sorted_entries);
|
|
||||||
dh::XGBDeviceAllocator<char> alloc;
|
|
||||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
|
||||||
sorted_entries.end(), detail::EntryCompareOp());
|
|
||||||
|
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
|
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, nullptr,
|
||||||
&column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
auto const &h_cuts_ptr = cuts_ptr.HostVector();
|
auto const &h_cuts_ptr = cuts_ptr.HostVector();
|
||||||
// Extract the cuts from all columns concurrently
|
// Extract the cuts from all columns concurrently
|
||||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
|
||||||
h_cuts_ptr.back());
|
h_cuts_ptr.back());
|
||||||
sorted_entries.clear();
|
sorted_entries.clear();
|
||||||
sorted_entries.shrink_to_fit();
|
sorted_entries.shrink_to_fit();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Batch>
|
template <typename Batch>
|
||||||
void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info,
|
||||||
int num_cuts_per_feature,
|
int num_cuts_per_feature, bool is_ranking, float missing,
|
||||||
bool is_ranking, float missing, DeviceOrd device,
|
DeviceOrd device, size_t columns, size_t begin, size_t end,
|
||||||
size_t columns, size_t begin, size_t end,
|
|
||||||
SketchContainer* sketch_container) {
|
SketchContainer* sketch_container) {
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||||
info.weights_.SetDevice(device);
|
info.weights_.SetDevice(device);
|
||||||
auto weights = info.weights_.ConstDeviceSpan();
|
auto weights = info.weights_.ConstDeviceSpan();
|
||||||
@ -329,14 +325,12 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||||
thrust::make_counting_iterator(0llu),
|
thrust::make_counting_iterator(0llu),
|
||||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||||
|
auto cuctx = ctx->CUDACtx();
|
||||||
dh::device_vector<Entry> sorted_entries;
|
dh::device_vector<Entry> sorted_entries;
|
||||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||||
detail::MakeEntriesFromAdapter(batch, batch_iter,
|
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns,
|
||||||
{begin, end}, missing,
|
num_cuts_per_feature, device, &cuts_ptr, &column_sizes_scan,
|
||||||
columns, num_cuts_per_feature, device,
|
|
||||||
&cuts_ptr,
|
|
||||||
&column_sizes_scan,
|
|
||||||
&sorted_entries);
|
&sorted_entries);
|
||||||
data::IsValidFunctor is_valid(missing);
|
data::IsValidFunctor is_valid(missing);
|
||||||
|
|
||||||
@ -355,7 +349,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
||||||
return weights[group_idx];
|
return weights[group_idx];
|
||||||
});
|
});
|
||||||
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
|
auto retit = thrust::copy_if(cuctx->CTP(),
|
||||||
weight_iter + begin, weight_iter + end,
|
weight_iter + begin, weight_iter + end,
|
||||||
batch_iter + begin,
|
batch_iter + begin,
|
||||||
d_temp_weights.data(), // output
|
d_temp_weights.data(), // output
|
||||||
@ -368,7 +362,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
[=]__device__(size_t idx) -> float {
|
[=]__device__(size_t idx) -> float {
|
||||||
return weights[batch.GetElement(idx).row_idx];
|
return weights[batch.GetElement(idx).row_idx];
|
||||||
});
|
});
|
||||||
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
|
auto retit = thrust::copy_if(cuctx->CTP(),
|
||||||
weight_iter + begin, weight_iter + end,
|
weight_iter + begin, weight_iter + end,
|
||||||
batch_iter + begin,
|
batch_iter + begin,
|
||||||
d_temp_weights.data(), // output
|
d_temp_weights.data(), // output
|
||||||
@ -376,11 +370,11 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
detail::SortByWeight(&temp_weights, &sorted_entries);
|
detail::SortByWeight(ctx, &temp_weights, &sorted_entries);
|
||||||
|
|
||||||
if (sketch_container->HasCategorical()) {
|
if (sketch_container->HasCategorical()) {
|
||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
|
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, &temp_weights,
|
||||||
&column_sizes_scan);
|
&column_sizes_scan);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -388,8 +382,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
|
||||||
|
|
||||||
// Extract cuts
|
// Extract cuts
|
||||||
sketch_container->Push(dh::ToSpan(sorted_entries),
|
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
||||||
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
|
|
||||||
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
|
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
|
||||||
sorted_entries.clear();
|
sorted_entries.clear();
|
||||||
sorted_entries.shrink_to_fit();
|
sorted_entries.shrink_to_fit();
|
||||||
@ -407,8 +400,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
* testing.
|
* testing.
|
||||||
*/
|
*/
|
||||||
template <typename Batch>
|
template <typename Batch>
|
||||||
void AdapterDeviceSketch(Batch batch, int num_bins,
|
void AdapterDeviceSketch(Context const* ctx, Batch batch, int num_bins, MetaInfo const& info,
|
||||||
MetaInfo const& info,
|
|
||||||
float missing, SketchContainer* sketch_container,
|
float missing, SketchContainer* sketch_container,
|
||||||
size_t sketch_batch_num_elements = 0) {
|
size_t sketch_batch_num_elements = 0) {
|
||||||
size_t num_rows = batch.NumRows();
|
size_t num_rows = batch.NumRows();
|
||||||
@ -419,27 +411,24 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
|
|||||||
|
|
||||||
if (weighted) {
|
if (weighted) {
|
||||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||||
sketch_batch_num_elements,
|
sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
||||||
num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
|
||||||
device.ordinal, num_cuts_per_feature, true);
|
device.ordinal, num_cuts_per_feature, true);
|
||||||
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
||||||
size_t end =
|
size_t end =
|
||||||
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
|
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
|
||||||
ProcessWeightedSlidingWindow(batch, info,
|
ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature,
|
||||||
num_cuts_per_feature,
|
HostSketchContainer::UseGroup(info), missing, device, num_cols,
|
||||||
HostSketchContainer::UseGroup(info), missing, device, num_cols, begin, end,
|
begin, end, sketch_container);
|
||||||
sketch_container);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||||
sketch_batch_num_elements,
|
sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
||||||
num_rows, num_cols, std::numeric_limits<size_t>::max(),
|
|
||||||
device.ordinal, num_cuts_per_feature, false);
|
device.ordinal, num_cuts_per_feature, false);
|
||||||
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
|
||||||
size_t end =
|
size_t end =
|
||||||
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
|
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
|
||||||
ProcessSlidingWindow(batch, info, device, num_cols, begin, end, missing,
|
ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container,
|
||||||
sketch_container, num_cuts_per_feature);
|
num_cuts_per_feature);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,6 +18,8 @@
|
|||||||
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank
|
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank
|
||||||
#include "categorical.h"
|
#include "categorical.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "cuda_context.cuh" // for CUDAContext
|
||||||
|
#include "cuda_rt_utils.h" // for SetDevice
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "hist_util.h"
|
#include "hist_util.h"
|
||||||
#include "quantile.cuh"
|
#include "quantile.cuh"
|
||||||
@ -117,6 +119,7 @@ void CopyTo(Span<T> out, Span<U> src) {
|
|||||||
|
|
||||||
// Compute the merge path.
|
// Compute the merge path.
|
||||||
common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
||||||
|
Context const* ctx,
|
||||||
Span<SketchEntry const> const &d_x, Span<bst_idx_t const> const &x_ptr,
|
Span<SketchEntry const> const &d_x, Span<bst_idx_t const> const &x_ptr,
|
||||||
Span<SketchEntry const> const &d_y, Span<bst_idx_t const> const &y_ptr,
|
Span<SketchEntry const> const &d_y, Span<bst_idx_t const> const &y_ptr,
|
||||||
Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
|
Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
|
||||||
@ -142,13 +145,12 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
|||||||
auto y_merge_val_it =
|
auto y_merge_val_it =
|
||||||
thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder));
|
thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder));
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<Tuple> alloc;
|
|
||||||
static_assert(sizeof(Tuple) == sizeof(SketchEntry));
|
static_assert(sizeof(Tuple) == sizeof(SketchEntry));
|
||||||
// We reuse the memory for storing merge path.
|
// We reuse the memory for storing merge path.
|
||||||
common::Span<Tuple> merge_path{reinterpret_cast<Tuple *>(out.data()), out.size()};
|
common::Span<Tuple> merge_path{reinterpret_cast<Tuple *>(out.data()), out.size()};
|
||||||
// Determine the merge path, 0 if element is from x, 1 if it's from y.
|
// Determine the merge path, 0 if element is from x, 1 if it's from y.
|
||||||
thrust::merge_by_key(
|
thrust::merge_by_key(
|
||||||
thrust::cuda::par(alloc), x_merge_key_it, x_merge_key_it + d_x.size(),
|
ctx->CUDACtx()->CTP(), x_merge_key_it, x_merge_key_it + d_x.size(),
|
||||||
y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it,
|
y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it,
|
||||||
y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(),
|
y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(),
|
||||||
[=] __device__(auto const &l, auto const &r) -> bool {
|
[=] __device__(auto const &l, auto const &r) -> bool {
|
||||||
@ -163,8 +165,7 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
|||||||
// Compute output ptr
|
// Compute output ptr
|
||||||
auto transform_it =
|
auto transform_it =
|
||||||
thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data()));
|
thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data()));
|
||||||
thrust::transform(
|
thrust::transform(ctx->CUDACtx()->CTP(), transform_it, transform_it + x_ptr.size(),
|
||||||
thrust::cuda::par(alloc), transform_it, transform_it + x_ptr.size(),
|
|
||||||
out_ptr.data(),
|
out_ptr.data(),
|
||||||
[] __device__(auto const &t) { return thrust::get<0>(t) + thrust::get<1>(t); });
|
[] __device__(auto const &t) { return thrust::get<0>(t) + thrust::get<1>(t); });
|
||||||
|
|
||||||
@ -194,7 +195,7 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
|||||||
// is landed into output as the first element in merge result. The scan result is the
|
// is landed into output as the first element in merge result. The scan result is the
|
||||||
// subscript of x and y.
|
// subscript of x and y.
|
||||||
thrust::exclusive_scan_by_key(
|
thrust::exclusive_scan_by_key(
|
||||||
thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(),
|
ctx->CUDACtx()->CTP(), scan_key_it, scan_key_it + merge_path.size(),
|
||||||
scan_val_it, merge_path.data(),
|
scan_val_it, merge_path.data(),
|
||||||
thrust::make_tuple<uint64_t, uint64_t>(0ul, 0ul),
|
thrust::make_tuple<uint64_t, uint64_t>(0ul, 0ul),
|
||||||
thrust::equal_to<size_t>{},
|
thrust::equal_to<size_t>{},
|
||||||
@ -209,18 +210,17 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
|||||||
// summary does the output element come from) result by definition of merged rank. So we
|
// summary does the output element come from) result by definition of merged rank. So we
|
||||||
// run it in 2 passes to obtain the merge path and then customize the standard merge
|
// run it in 2 passes to obtain the merge path and then customize the standard merge
|
||||||
// algorithm.
|
// algorithm.
|
||||||
void MergeImpl(DeviceOrd device, Span<SketchEntry const> const &d_x,
|
void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
|
||||||
Span<bst_idx_t const> const &x_ptr, Span<SketchEntry const> const &d_y,
|
Span<bst_idx_t const> const &x_ptr, Span<SketchEntry const> const &d_y,
|
||||||
Span<bst_idx_t const> const &y_ptr, Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
|
Span<bst_idx_t const> const &y_ptr, Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
|
||||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
|
||||||
CHECK_EQ(d_x.size() + d_y.size(), out.size());
|
CHECK_EQ(d_x.size() + d_y.size(), out.size());
|
||||||
CHECK_EQ(x_ptr.size(), out_ptr.size());
|
CHECK_EQ(x_ptr.size(), out_ptr.size());
|
||||||
CHECK_EQ(y_ptr.size(), out_ptr.size());
|
CHECK_EQ(y_ptr.size(), out_ptr.size());
|
||||||
|
|
||||||
auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr);
|
auto d_merge_path = MergePath(ctx, d_x, x_ptr, d_y, y_ptr, out, out_ptr);
|
||||||
auto d_out = out;
|
auto d_out = out;
|
||||||
|
|
||||||
dh::LaunchN(d_out.size(), [=] __device__(size_t idx) {
|
dh::LaunchN(d_out.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) {
|
||||||
auto column_id = dh::SegmentId(out_ptr, idx);
|
auto column_id = dh::SegmentId(out_ptr, idx);
|
||||||
idx -= out_ptr[column_id];
|
idx -= out_ptr[column_id];
|
||||||
|
|
||||||
@ -307,10 +307,9 @@ void MergeImpl(DeviceOrd device, Span<SketchEntry const> const &d_x,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||||
common::Span<OffsetT> cuts_ptr,
|
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) {
|
||||||
size_t total_cuts, Span<float> weights) {
|
common::SetDevice(device_.ordinal);
|
||||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
|
||||||
Span<SketchEntry> out;
|
Span<SketchEntry> out;
|
||||||
dh::device_vector<SketchEntry> cuts;
|
dh::device_vector<SketchEntry> cuts;
|
||||||
bool first_window = this->Current().empty();
|
bool first_window = this->Current().empty();
|
||||||
@ -346,12 +345,12 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
|||||||
}; // NOLINT
|
}; // NOLINT
|
||||||
PruneImpl<Entry>(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry);
|
PruneImpl<Entry>(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry);
|
||||||
}
|
}
|
||||||
auto n_uniques = this->ScanInput(out, cuts_ptr);
|
auto n_uniques = this->ScanInput(ctx, out, cuts_ptr);
|
||||||
|
|
||||||
if (!first_window) {
|
if (!first_window) {
|
||||||
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
|
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
|
||||||
out = out.subspan(0, n_uniques);
|
out = out.subspan(0, n_uniques);
|
||||||
this->Merge(cuts_ptr, out);
|
this->Merge(ctx, cuts_ptr, out);
|
||||||
this->FixError();
|
this->FixError();
|
||||||
} else {
|
} else {
|
||||||
this->Current().resize(n_uniques);
|
this->Current().resize(n_uniques);
|
||||||
@ -363,7 +362,8 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in) {
|
size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
|
||||||
|
Span<OffsetT> d_columns_ptr_in) {
|
||||||
/* There are 2 types of duplication. First is duplicated feature values, which comes
|
/* There are 2 types of duplication. First is duplicated feature values, which comes
|
||||||
* from user input data. Second is duplicated sketching entries, which is generated by
|
* from user input data. Second is duplicated sketching entries, which is generated by
|
||||||
* pruning or merging. We preserve the first type and remove the second type.
|
* pruning or merging. We preserve the first type and remove the second type.
|
||||||
@ -371,7 +371,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
|
|||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||||
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
|
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
|
|
||||||
auto key_it = dh::MakeTransformIterator<size_t>(
|
auto key_it = dh::MakeTransformIterator<size_t>(
|
||||||
thrust::make_reverse_iterator(thrust::make_counting_iterator(entries.size())),
|
thrust::make_reverse_iterator(thrust::make_counting_iterator(entries.size())),
|
||||||
@ -381,7 +380,7 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
|
|||||||
// Reverse scan to accumulate weights into first duplicated element on left.
|
// Reverse scan to accumulate weights into first duplicated element on left.
|
||||||
auto val_it = thrust::make_reverse_iterator(dh::tend(entries));
|
auto val_it = thrust::make_reverse_iterator(dh::tend(entries));
|
||||||
thrust::inclusive_scan_by_key(
|
thrust::inclusive_scan_by_key(
|
||||||
thrust::cuda::par(alloc), key_it, key_it + entries.size(),
|
ctx->CUDACtx()->CTP(), key_it, key_it + entries.size(),
|
||||||
val_it, val_it,
|
val_it, val_it,
|
||||||
thrust::equal_to<size_t>{},
|
thrust::equal_to<size_t>{},
|
||||||
[] __device__(SketchEntry const &r, SketchEntry const &l) {
|
[] __device__(SketchEntry const &r, SketchEntry const &l) {
|
||||||
@ -396,8 +395,8 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
|
|||||||
|
|
||||||
auto d_columns_ptr_out = columns_ptr_b_.DeviceSpan();
|
auto d_columns_ptr_out = columns_ptr_b_.DeviceSpan();
|
||||||
// thrust unique_by_key preserves the first element.
|
// thrust unique_by_key preserves the first element.
|
||||||
auto n_uniques = dh::SegmentedUnique(
|
auto n_uniques =
|
||||||
d_columns_ptr_in.data(),
|
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), d_columns_ptr_in.data(),
|
||||||
d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(),
|
d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(),
|
||||||
entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(),
|
entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(),
|
||||||
detail::SketchUnique{});
|
detail::SketchUnique{});
|
||||||
@ -407,7 +406,7 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
|
|||||||
return n_uniques;
|
return n_uniques;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SketchContainer::Prune(size_t to) {
|
void SketchContainer::Prune(Context const* ctx, std::size_t to) {
|
||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||||
|
|
||||||
@ -438,19 +437,19 @@ void SketchContainer::Prune(size_t to) {
|
|||||||
this->columns_ptr_.Copy(columns_ptr_b_);
|
this->columns_ptr_.Copy(columns_ptr_b_);
|
||||||
this->Alternate();
|
this->Alternate();
|
||||||
|
|
||||||
this->Unique();
|
this->Unique(ctx);
|
||||||
timer_.Stop(__func__);
|
timer_.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
|
void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_columns_ptr,
|
||||||
Span<SketchEntry const> that) {
|
Span<SketchEntry const> that) {
|
||||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
common::SetDevice(device_.ordinal);
|
||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
if (this->Current().size() == 0) {
|
if (this->Current().size() == 0) {
|
||||||
CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
|
CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
|
||||||
CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size());
|
CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size());
|
||||||
CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1);
|
CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1);
|
||||||
thrust::copy(thrust::device, d_that_columns_ptr.data(),
|
thrust::copy(ctx->CUDACtx()->CTP(), d_that_columns_ptr.data(),
|
||||||
d_that_columns_ptr.data() + d_that_columns_ptr.size(),
|
d_that_columns_ptr.data() + d_that_columns_ptr.size(),
|
||||||
this->columns_ptr_.DevicePointer());
|
this->columns_ptr_.DevicePointer());
|
||||||
auto total = this->columns_ptr_.HostVector().back();
|
auto total = this->columns_ptr_.HostVector().back();
|
||||||
@ -463,7 +462,7 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
|
|||||||
this->Other().resize(this->Current().size() + that.size());
|
this->Other().resize(this->Current().size() + that.size());
|
||||||
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
|
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
|
||||||
|
|
||||||
MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
|
MergeImpl(ctx, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
|
||||||
dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
|
dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
|
||||||
this->columns_ptr_.Copy(columns_ptr_b_);
|
this->columns_ptr_.Copy(columns_ptr_b_);
|
||||||
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
|
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
|
||||||
@ -471,7 +470,7 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
|
|||||||
|
|
||||||
if (this->HasCategorical()) {
|
if (this->HasCategorical()) {
|
||||||
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
|
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
|
||||||
this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
|
this->Unique(ctx, [d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
|
||||||
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
|
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -517,7 +516,7 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
|
|||||||
SafeColl(rc);
|
SafeColl(rc);
|
||||||
bst_idx_t intermediate_num_cuts =
|
bst_idx_t intermediate_num_cuts =
|
||||||
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
|
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
|
||||||
this->Prune(intermediate_num_cuts);
|
this->Prune(ctx, intermediate_num_cuts);
|
||||||
|
|
||||||
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
|
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
|
||||||
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
|
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
|
||||||
@ -570,9 +569,8 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
|
|||||||
for (size_t i = 0; i < allworkers.size(); ++i) {
|
for (size_t i = 0; i < allworkers.size(); ++i) {
|
||||||
auto worker = allworkers[i];
|
auto worker = allworkers[i];
|
||||||
auto worker_ptr =
|
auto worker_ptr =
|
||||||
dh::ToSpan(gathered_ptrs)
|
dh::ToSpan(gathered_ptrs).subspan(i * d_columns_ptr.size(), d_columns_ptr.size());
|
||||||
.subspan(i * d_columns_ptr.size(), d_columns_ptr.size());
|
new_sketch.Merge(ctx, worker_ptr, worker);
|
||||||
new_sketch.Merge(worker_ptr, worker);
|
|
||||||
new_sketch.FixError();
|
new_sketch.FixError();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -602,7 +600,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
|
|||||||
this->AllReduce(ctx, is_column_split);
|
this->AllReduce(ctx, is_column_split);
|
||||||
|
|
||||||
// Prune to final number of bins.
|
// Prune to final number of bins.
|
||||||
this->Prune(num_bins_ + 1);
|
this->Prune(ctx, num_bins_ + 1);
|
||||||
this->FixError();
|
this->FixError();
|
||||||
|
|
||||||
// Set up inputs
|
// Set up inputs
|
||||||
@ -624,7 +622,6 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
|
|||||||
std::vector<SketchEntry> max_values;
|
std::vector<SketchEntry> max_values;
|
||||||
float max_cat{-1.f};
|
float max_cat{-1.f};
|
||||||
if (has_categorical_) {
|
if (has_categorical_) {
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
auto key_it = dh::MakeTransformIterator<bst_feature_t>(
|
auto key_it = dh::MakeTransformIterator<bst_feature_t>(
|
||||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t {
|
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t {
|
||||||
return dh::SegmentId(d_in_columns_ptr, i);
|
return dh::SegmentId(d_in_columns_ptr, i);
|
||||||
@ -651,7 +648,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
|
|||||||
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1);
|
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1);
|
||||||
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
|
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
|
||||||
auto new_end = thrust::reduce_by_key(
|
auto new_end = thrust::reduce_by_key(
|
||||||
thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
|
ctx->CUDACtx()->CTP(), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
|
||||||
d_max_values.begin(), thrust::equal_to<bst_feature_t>{},
|
d_max_values.begin(), thrust::equal_to<bst_feature_t>{},
|
||||||
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
|
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
|
||||||
d_max_keys.erase(new_end.first, d_max_keys.end());
|
d_max_keys.erase(new_end.first, d_max_keys.end());
|
||||||
@ -661,7 +658,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
|
|||||||
SketchEntry default_entry{};
|
SketchEntry default_entry{};
|
||||||
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1,
|
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1,
|
||||||
default_entry);
|
default_entry);
|
||||||
thrust::scatter(thrust::cuda::par(alloc), d_max_values.begin(), d_max_values.end(),
|
thrust::scatter(ctx->CUDACtx()->CTP(), d_max_values.begin(), d_max_values.end(),
|
||||||
d_max_keys.begin(), d_max_results.begin());
|
d_max_keys.begin(), d_max_results.begin());
|
||||||
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results));
|
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results));
|
||||||
auto max_it = MakeIndexTransformIter([&](auto i) {
|
auto max_it = MakeIndexTransformIter([&](auto i) {
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
#include <thrust/logical.h> // for any_of
|
#include <thrust/logical.h> // for any_of
|
||||||
|
|
||||||
#include "categorical.h"
|
#include "categorical.h"
|
||||||
|
#include "cuda_context.cuh" // for CUDAContext
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#include "error_msg.h" // for InvalidMaxBin
|
#include "error_msg.h" // for InvalidMaxBin
|
||||||
#include "quantile.h"
|
#include "quantile.h"
|
||||||
@ -127,7 +128,7 @@ class SketchContainer {
|
|||||||
/* \brief Whether the predictor matrix contains categorical features. */
|
/* \brief Whether the predictor matrix contains categorical features. */
|
||||||
bool HasCategorical() const { return has_categorical_; }
|
bool HasCategorical() const { return has_categorical_; }
|
||||||
/* \brief Accumulate weights of duplicated entries in input. */
|
/* \brief Accumulate weights of duplicated entries in input. */
|
||||||
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
|
size_t ScanInput(Context const* ctx, Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
|
||||||
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
|
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
|
||||||
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
|
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
|
||||||
void FixError();
|
void FixError();
|
||||||
@ -140,19 +141,18 @@ class SketchContainer {
|
|||||||
* \param total_cuts Total number of cuts, equal to the back of cuts_ptr.
|
* \param total_cuts Total number of cuts, equal to the back of cuts_ptr.
|
||||||
* \param weights (optional) data weights.
|
* \param weights (optional) data weights.
|
||||||
*/
|
*/
|
||||||
void Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
void Push(Context const* ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||||
common::Span<OffsetT> cuts_ptr, size_t total_cuts,
|
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights = {});
|
||||||
Span<float> weights = {});
|
|
||||||
/* \brief Prune the quantile structure.
|
/* \brief Prune the quantile structure.
|
||||||
*
|
*
|
||||||
* \param to The maximum size of pruned quantile. If the size of quantile
|
* \param to The maximum size of pruned quantile. If the size of quantile
|
||||||
* structure is already less than `to`, then no operation is performed.
|
* structure is already less than `to`, then no operation is performed.
|
||||||
*/
|
*/
|
||||||
void Prune(size_t to);
|
void Prune(Context const* ctx, size_t to);
|
||||||
/* \brief Merge another set of sketch.
|
/* \brief Merge another set of sketch.
|
||||||
* \param that columns of other.
|
* \param that columns of other.
|
||||||
*/
|
*/
|
||||||
void Merge(Span<OffsetT const> that_columns_ptr,
|
void Merge(Context const* ctx, Span<OffsetT const> that_columns_ptr,
|
||||||
Span<SketchEntry const> that);
|
Span<SketchEntry const> that);
|
||||||
|
|
||||||
/* \brief Merge quantiles from other GPU workers. */
|
/* \brief Merge quantiles from other GPU workers. */
|
||||||
@ -175,7 +175,7 @@ class SketchContainer {
|
|||||||
|
|
||||||
/* \brief Removes all the duplicated elements in quantile structure. */
|
/* \brief Removes all the duplicated elements in quantile structure. */
|
||||||
template <typename KeyComp = thrust::equal_to<size_t>>
|
template <typename KeyComp = thrust::equal_to<size_t>>
|
||||||
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||||
this->columns_ptr_.SetDevice(device_);
|
this->columns_ptr_.SetDevice(device_);
|
||||||
@ -185,14 +185,12 @@ class SketchContainer {
|
|||||||
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
|
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
|
||||||
scan_out.SetDevice(device_);
|
scan_out.SetDevice(device_);
|
||||||
auto d_scan_out = scan_out.DeviceSpan();
|
auto d_scan_out = scan_out.DeviceSpan();
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
|
|
||||||
d_column_scan = this->columns_ptr_.DeviceSpan();
|
d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||||
size_t n_uniques = dh::SegmentedUnique(
|
size_t n_uniques = dh::SegmentedUnique(
|
||||||
thrust::cuda::par(alloc), d_column_scan.data(),
|
ctx->CUDACtx()->CTP(), d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
|
||||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(), entries.data(),
|
||||||
entries.data() + entries.size(), scan_out.DevicePointer(),
|
detail::SketchUnique{}, key_comp);
|
||||||
entries.data(), detail::SketchUnique{}, key_comp);
|
|
||||||
this->columns_ptr_.Copy(scan_out);
|
this->columns_ptr_.Copy(scan_out);
|
||||||
CHECK(!this->columns_ptr_.HostCanRead());
|
CHECK(!this->columns_ptr_.HostCanRead());
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@
|
|||||||
#include <type_traits> // for invoke_result_t, declval
|
#include <type_traits> // for invoke_result_t, declval
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
#include "xgboost/c_api.h"
|
#include "xgboost/c_api.h"
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h"
|
||||||
@ -36,6 +37,8 @@ class DataIterProxy {
|
|||||||
DataIterProxy& operator=(DataIterProxy const& that) = default;
|
DataIterProxy& operator=(DataIterProxy const& that) = default;
|
||||||
|
|
||||||
[[nodiscard]] bool Next() {
|
[[nodiscard]] bool Next() {
|
||||||
|
xgboost_NVTX_FN_RANGE();
|
||||||
|
|
||||||
bool ret = !!next_(iter_);
|
bool ret = !!next_(iter_);
|
||||||
if (!ret) {
|
if (!ret) {
|
||||||
return ret;
|
return ret;
|
||||||
|
|||||||
@ -30,14 +30,13 @@ void MakeSketches(Context const* ctx,
|
|||||||
ExternalDataInfo* p_ext_info) {
|
ExternalDataInfo* p_ext_info) {
|
||||||
xgboost_NVTX_FN_RANGE();
|
xgboost_NVTX_FN_RANGE();
|
||||||
|
|
||||||
CUDAContext const* cuctx = ctx->CUDACtx();
|
|
||||||
std::unique_ptr<common::SketchContainer> sketch;
|
std::unique_ptr<common::SketchContainer> sketch;
|
||||||
auto& ext_info = *p_ext_info;
|
auto& ext_info = *p_ext_info;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
// We use do while here as the first batch is fetched in ctor
|
// We use do while here as the first batch is fetched in ctor
|
||||||
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
|
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
|
||||||
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
common::SetDevice(dh::GetDevice(ctx).ordinal);
|
||||||
if (ext_info.n_features == 0) {
|
if (ext_info.n_features == 0) {
|
||||||
ext_info.n_features = data::BatchColumns(proxy);
|
ext_info.n_features = data::BatchColumns(proxy);
|
||||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&ext_info.n_features, 1),
|
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&ext_info.n_features, 1),
|
||||||
@ -55,7 +54,16 @@ void MakeSketches(Context const* ctx,
|
|||||||
}
|
}
|
||||||
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
|
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
|
||||||
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
||||||
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get());
|
// Workaround empty input with CPU ctx.
|
||||||
|
Context new_ctx;
|
||||||
|
Context const* p_ctx;
|
||||||
|
if (ctx->IsCUDA()) {
|
||||||
|
p_ctx = ctx;
|
||||||
|
} else {
|
||||||
|
new_ctx.UpdateAllowUnknown(Args{{"device", dh::GetDevice(ctx).Name()}});
|
||||||
|
p_ctx = &new_ctx;
|
||||||
|
}
|
||||||
|
common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing, sketch.get());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
auto batch_rows = data::BatchSamples(proxy);
|
auto batch_rows = data::BatchSamples(proxy);
|
||||||
@ -66,7 +74,7 @@ void MakeSketches(Context const* ctx,
|
|||||||
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
||||||
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
||||||
}));
|
}));
|
||||||
ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end());
|
ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end());
|
||||||
ext_info.n_batches++;
|
ext_info.n_batches++;
|
||||||
ext_info.base_rows.push_back(batch_rows);
|
ext_info.base_rows.push_back(batch_rows);
|
||||||
} while (iter->Next());
|
} while (iter->Next());
|
||||||
@ -77,7 +85,7 @@ void MakeSketches(Context const* ctx,
|
|||||||
ext_info.base_rows.begin());
|
ext_info.base_rows.begin());
|
||||||
|
|
||||||
// Get reference
|
// Get reference
|
||||||
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
common::SetDevice(dh::GetDevice(ctx).ordinal);
|
||||||
if (!ref) {
|
if (!ref) {
|
||||||
sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#include "../common/device_helpers.cuh"
|
#include "../common/device_helpers.cuh"
|
||||||
#include "../common/error_msg.h" // for InfInData
|
#include "../common/error_msg.h" // for InfInData
|
||||||
|
#include "../common/algorithm.cuh" // for CopyIf
|
||||||
#include "device_adapter.cuh" // for NoInfInData
|
#include "device_adapter.cuh" // for NoInfInData
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
@ -27,15 +28,14 @@ struct COOToEntryOp {
|
|||||||
// Here the data is already correctly ordered and simply needs to be compacted
|
// Here the data is already correctly ordered and simply needs to be compacted
|
||||||
// to remove missing data
|
// to remove missing data
|
||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
|
void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data, float missing) {
|
||||||
float missing) {
|
|
||||||
auto counting = thrust::make_counting_iterator(0llu);
|
auto counting = thrust::make_counting_iterator(0llu);
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
COOToEntryOp<decltype(batch)> transform_op{batch};
|
COOToEntryOp<decltype(batch)> transform_op{batch};
|
||||||
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
|
thrust::transform_iterator<decltype(transform_op), decltype(counting)> transform_iter(
|
||||||
transform_iter(counting, transform_op);
|
counting, transform_op);
|
||||||
auto begin_output = thrust::device_pointer_cast(data.data());
|
auto begin_output = thrust::device_pointer_cast(data.data());
|
||||||
dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
|
auto ctx = Context{}.MakeCUDA(dh::CurrentDevice());
|
||||||
|
common::CopyIf(ctx.CUDACtx(), transform_iter, transform_iter + batch.Size(), begin_output,
|
||||||
IsValidFunctor(missing));
|
IsValidFunctor(missing));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -9,8 +9,10 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../../../src/common/cuda_context.cuh"
|
||||||
#include "../../../src/common/device_helpers.cuh"
|
#include "../../../src/common/device_helpers.cuh"
|
||||||
#include "../../../src/common/quantile.h"
|
#include "../../../src/common/quantile.h"
|
||||||
|
#include "../helpers.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
TEST(SumReduce, Test) {
|
TEST(SumReduce, Test) {
|
||||||
@ -61,11 +63,11 @@ TEST(SegmentedUnique, Basic) {
|
|||||||
thrust::device_vector<xgboost::bst_feature_t> d_segs_out(d_segments.size());
|
thrust::device_vector<xgboost::bst_feature_t> d_segs_out(d_segments.size());
|
||||||
thrust::device_vector<float> d_vals_out(d_values.size());
|
thrust::device_vector<float> d_vals_out(d_values.size());
|
||||||
|
|
||||||
|
auto ctx = xgboost::MakeCUDACtx(0);
|
||||||
size_t n_uniques = dh::SegmentedUnique(
|
size_t n_uniques = dh::SegmentedUnique(
|
||||||
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
|
ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(),
|
||||||
d_values.data().get(), d_values.data().get() + d_values.size(),
|
d_values.data().get(), d_values.data().get() + d_values.size(), d_segs_out.data().get(),
|
||||||
d_segs_out.data().get(), d_vals_out.data().get(),
|
d_vals_out.data().get(), thrust::equal_to<float>{});
|
||||||
thrust::equal_to<float>{});
|
|
||||||
CHECK_EQ(n_uniques, 5);
|
CHECK_EQ(n_uniques, 5);
|
||||||
|
|
||||||
std::vector<float> values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f};
|
std::vector<float> values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f};
|
||||||
@ -81,10 +83,9 @@ TEST(SegmentedUnique, Basic) {
|
|||||||
d_segments[1] = 4;
|
d_segments[1] = 4;
|
||||||
d_segments[2] = 6;
|
d_segments[2] = 6;
|
||||||
n_uniques = dh::SegmentedUnique(
|
n_uniques = dh::SegmentedUnique(
|
||||||
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
|
ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(),
|
||||||
d_values.data().get(), d_values.data().get() + d_values.size(),
|
d_values.data().get(), d_values.data().get() + d_values.size(), d_segs_out.data().get(),
|
||||||
d_segs_out.data().get(), d_vals_out.data().get(),
|
d_vals_out.data().get(), thrust::equal_to<float>{});
|
||||||
thrust::equal_to<float>{});
|
|
||||||
ASSERT_EQ(n_uniques, values.size());
|
ASSERT_EQ(n_uniques, values.size());
|
||||||
for (size_t i = 0 ; i < values.size(); i ++) {
|
for (size_t i = 0 ; i < values.size(); i ++) {
|
||||||
ASSERT_EQ(d_vals_out[i], values[i]);
|
ASSERT_EQ(d_vals_out[i], values[i]);
|
||||||
@ -113,10 +114,12 @@ void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_dup
|
|||||||
thrust::device_vector<bst_feature_t> d_segments(segments);
|
thrust::device_vector<bst_feature_t> d_segments(segments);
|
||||||
thrust::device_vector<bst_feature_t> d_segments_out(segments.size());
|
thrust::device_vector<bst_feature_t> d_segments_out(segments.size());
|
||||||
|
|
||||||
|
auto ctx = xgboost::MakeCUDACtx(0);
|
||||||
|
|
||||||
size_t n_uniques = dh::SegmentedUnique(
|
size_t n_uniques = dh::SegmentedUnique(
|
||||||
d_segments.data().get(), d_segments.data().get() + d_segments.size(), d_values.data().get(),
|
ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(),
|
||||||
d_values.data().get() + d_values.size(), d_segments_out.data().get(), d_values.data().get(),
|
d_values.data().get(), d_values.data().get() + d_values.size(), d_segments_out.data().get(),
|
||||||
SketchUnique{});
|
d_values.data().get(), SketchUnique{});
|
||||||
ASSERT_EQ(n_uniques, values.size() - n_duplicated);
|
ASSERT_EQ(n_uniques, values.size() - n_duplicated);
|
||||||
ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(),
|
ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(),
|
||||||
d_values.begin() + n_uniques, IsSorted{}));
|
d_values.begin() + n_uniques, IsSorted{}));
|
||||||
|
|||||||
@ -221,8 +221,8 @@ TEST(HistUtil, RemoveDuplicatedCategories) {
|
|||||||
thrust::sort_by_key(sorted_entries.begin(), sorted_entries.end(), weight.begin(),
|
thrust::sort_by_key(sorted_entries.begin(), sorted_entries.end(), weight.begin(),
|
||||||
detail::EntryCompareOp());
|
detail::EntryCompareOp());
|
||||||
|
|
||||||
detail::RemoveDuplicatedCategories(ctx.Device(), info, cuts_ptr.DeviceSpan(), &sorted_entries,
|
detail::RemoveDuplicatedCategories(&ctx, info, cuts_ptr.DeviceSpan(), &sorted_entries, &weight,
|
||||||
&weight, &columns_ptr);
|
&columns_ptr);
|
||||||
|
|
||||||
auto const& h_cptr = cuts_ptr.ConstHostVector();
|
auto const& h_cptr = cuts_ptr.ConstHostVector();
|
||||||
ASSERT_EQ(h_cptr.back(), n_samples * 2 + n_categories);
|
ASSERT_EQ(h_cptr.back(), n_samples * 2 + n_categories);
|
||||||
@ -367,7 +367,7 @@ auto MakeUnweightedCutsForTest(Context const* ctx, Adapter adapter, int32_t num_
|
|||||||
SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(),
|
SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(),
|
||||||
DeviceOrd::CUDA(0));
|
DeviceOrd::CUDA(0));
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
|
AdapterDeviceSketch(ctx, adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
|
||||||
sketch_container.MakeCuts(ctx, &batched_cuts, info.IsColumnSplit());
|
sketch_container.MakeCuts(ctx, &batched_cuts, info.IsColumnSplit());
|
||||||
return batched_cuts;
|
return batched_cuts;
|
||||||
}
|
}
|
||||||
@ -437,8 +437,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
|
|||||||
common::HistogramCuts batched_cuts;
|
common::HistogramCuts batched_cuts;
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0));
|
SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0));
|
||||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info,
|
||||||
&sketch_container);
|
std::numeric_limits<float>::quiet_NaN(), &sketch_container);
|
||||||
HistogramCuts cuts;
|
HistogramCuts cuts;
|
||||||
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||||
size_t bytes_required = detail::RequiredMemory(
|
size_t bytes_required = detail::RequiredMemory(
|
||||||
@ -466,9 +466,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
|
|||||||
common::HistogramCuts batched_cuts;
|
common::HistogramCuts batched_cuts;
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0));
|
SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0));
|
||||||
AdapterDeviceSketch(adapter.Value(), num_bins, info,
|
AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(),
|
std::numeric_limits<float>::quiet_NaN(), &sketch_container);
|
||||||
&sketch_container);
|
|
||||||
|
|
||||||
HistogramCuts cuts;
|
HistogramCuts cuts;
|
||||||
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||||
@ -502,7 +501,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
|
|||||||
|
|
||||||
ASSERT_EQ(info.feature_types.Size(), 1);
|
ASSERT_EQ(info.feature_types.Size(), 1);
|
||||||
SketchContainer container(info.feature_types, num_bins, 1, n, DeviceOrd::CUDA(0));
|
SketchContainer container(info.feature_types, num_bins, 1, n, DeviceOrd::CUDA(0));
|
||||||
AdapterDeviceSketch(adapter.Value(), num_bins, info,
|
AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(), &container);
|
std::numeric_limits<float>::quiet_NaN(), &container);
|
||||||
HistogramCuts cuts;
|
HistogramCuts cuts;
|
||||||
container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||||
@ -616,22 +615,27 @@ void TestGetColumnSize(std::size_t n_samples) {
|
|||||||
std::vector<std::size_t> h_column_size(column_sizes_scan.size());
|
std::vector<std::size_t> h_column_size(column_sizes_scan.size());
|
||||||
std::vector<std::size_t> h_column_size_1(column_sizes_scan.size());
|
std::vector<std::size_t> h_column_size_1(column_sizes_scan.size());
|
||||||
|
|
||||||
|
auto cuctx = ctx.CUDACtx();
|
||||||
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, true>(
|
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, true>(
|
||||||
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
|
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
|
||||||
|
dh::ToSpan(column_sizes_scan));
|
||||||
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size.begin());
|
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size.begin());
|
||||||
|
|
||||||
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, false>(
|
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, false>(
|
||||||
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
|
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
|
||||||
|
dh::ToSpan(column_sizes_scan));
|
||||||
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
|
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
|
||||||
ASSERT_EQ(h_column_size, h_column_size_1);
|
ASSERT_EQ(h_column_size, h_column_size_1);
|
||||||
|
|
||||||
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, true>(
|
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, true>(
|
||||||
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
|
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
|
||||||
|
dh::ToSpan(column_sizes_scan));
|
||||||
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
|
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
|
||||||
ASSERT_EQ(h_column_size, h_column_size_1);
|
ASSERT_EQ(h_column_size, h_column_size_1);
|
||||||
|
|
||||||
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, false>(
|
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, false>(
|
||||||
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
|
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
|
||||||
|
dh::ToSpan(column_sizes_scan));
|
||||||
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
|
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
|
||||||
ASSERT_EQ(h_column_size, h_column_size_1);
|
ASSERT_EQ(h_column_size, h_column_size_1);
|
||||||
}
|
}
|
||||||
@ -737,7 +741,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
|||||||
auto const& batch = adapter.Value();
|
auto const& batch = adapter.Value();
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_container(ft, kBins, kCols, kRows, DeviceOrd::CUDA(0));
|
SketchContainer sketch_container(ft, kBins, kCols, kRows, DeviceOrd::CUDA(0));
|
||||||
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
AdapterDeviceSketch(&ctx, adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||||
&sketch_container);
|
&sketch_container);
|
||||||
|
|
||||||
common::HistogramCuts cuts;
|
common::HistogramCuts cuts;
|
||||||
@ -780,7 +784,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
|||||||
h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast<float>(kGroups);
|
h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast<float>(kGroups);
|
||||||
}
|
}
|
||||||
SketchContainer sketch_container{ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)};
|
SketchContainer sketch_container{ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)};
|
||||||
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
AdapterDeviceSketch(&ctx, adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||||
&sketch_container);
|
&sketch_container);
|
||||||
sketch_container.MakeCuts(&ctx, &weighted, info.IsColumnSplit());
|
sketch_container.MakeCuts(&ctx, &weighted, info.IsColumnSplit());
|
||||||
ValidateCuts(weighted, dmat.get(), kBins);
|
ValidateCuts(weighted, dmat.get(), kBins);
|
||||||
|
|||||||
@ -24,14 +24,15 @@ namespace common {
|
|||||||
class MGPUQuantileTest : public collective::BaseMGPUTest {};
|
class MGPUQuantileTest : public collective::BaseMGPUTest {};
|
||||||
|
|
||||||
TEST(GPUQuantile, Basic) {
|
TEST(GPUQuantile, Basic) {
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
|
constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch(ft, kBins, kCols, kRows, FstCU());
|
SketchContainer sketch(ft, kBins, kCols, kRows, ctx.Device());
|
||||||
dh::caching_device_vector<Entry> entries;
|
dh::caching_device_vector<Entry> entries;
|
||||||
dh::device_vector<bst_idx_t> cuts_ptr(kCols+1);
|
dh::device_vector<bst_idx_t> cuts_ptr(kCols+1);
|
||||||
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
|
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
|
||||||
// Push empty
|
// Push empty
|
||||||
sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0);
|
sketch.Push(&ctx, dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0);
|
||||||
ASSERT_EQ(sketch.Data().size(), 0);
|
ASSERT_EQ(sketch.Data().size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,16 +40,17 @@ void TestSketchUnique(float sparsity) {
|
|||||||
constexpr size_t kRows = 1000, kCols = 100;
|
constexpr size_t kRows = 1000, kCols = 100;
|
||||||
RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](std::int32_t seed, bst_bin_t n_bins,
|
RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](std::int32_t seed, bst_bin_t n_bins,
|
||||||
MetaInfo const& info) {
|
MetaInfo const& info) {
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
|
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
|
|
||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity}
|
std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity}
|
||||||
.Seed(seed)
|
.Seed(seed)
|
||||||
.Device(FstCU())
|
.Device(ctx.Device())
|
||||||
.GenerateArrayInterface(&storage);
|
.GenerateArrayInterface(&storage);
|
||||||
data::CupyAdapter adapter(interface_str);
|
data::CupyAdapter adapter(interface_str);
|
||||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
||||||
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
||||||
|
|
||||||
@ -60,8 +62,9 @@ void TestSketchUnique(float sparsity) {
|
|||||||
thrust::make_counting_iterator(0llu),
|
thrust::make_counting_iterator(0llu),
|
||||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||||
auto end = kCols * kRows;
|
auto end = kCols * kRows;
|
||||||
detail::GetColumnSizesScan(FstCU(), kCols, n_cuts, IterSpan{batch_iter, end}, is_valid,
|
detail::GetColumnSizesScan(ctx.CUDACtx(), ctx.Device(), kCols, n_cuts,
|
||||||
&cut_sizes_scan, &column_sizes_scan);
|
IterSpan{batch_iter, end}, is_valid, &cut_sizes_scan,
|
||||||
|
&column_sizes_scan);
|
||||||
auto const& cut_sizes = cut_sizes_scan.HostVector();
|
auto const& cut_sizes = cut_sizes_scan.HostVector();
|
||||||
ASSERT_LE(sketch.Data().size(), cut_sizes.back());
|
ASSERT_LE(sketch.Data().size(), cut_sizes.back());
|
||||||
|
|
||||||
@ -69,7 +72,7 @@ void TestSketchUnique(float sparsity) {
|
|||||||
dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch.ColumnsPtr());
|
dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch.ColumnsPtr());
|
||||||
ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back());
|
ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back());
|
||||||
|
|
||||||
sketch.Unique();
|
sketch.Unique(&ctx);
|
||||||
|
|
||||||
std::vector<SketchEntry> h_data(sketch.Data().size());
|
std::vector<SketchEntry> h_data(sketch.Data().size());
|
||||||
thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin());
|
thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin());
|
||||||
@ -124,44 +127,46 @@ void TestQuantileElemRank(DeviceOrd device, Span<SketchEntry const> in,
|
|||||||
TEST(GPUQuantile, Prune) {
|
TEST(GPUQuantile, Prune) {
|
||||||
constexpr size_t kRows = 1000, kCols = 100;
|
constexpr size_t kRows = 1000, kCols = 100;
|
||||||
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
|
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
|
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
|
|
||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
std::string interface_str =
|
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||||
RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
|
.Device(ctx.Device())
|
||||||
&storage);
|
.Seed(seed)
|
||||||
|
.GenerateArrayInterface(&storage);
|
||||||
data::CupyAdapter adapter(interface_str);
|
data::CupyAdapter adapter(interface_str);
|
||||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
std::numeric_limits<float>::quiet_NaN(), &sketch);
|
||||||
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
|
||||||
// LE because kRows * kCols is pushed into sketch, after removing
|
// LE because kRows * kCols is pushed into sketch, after removing
|
||||||
// duplicated entries we might not have that much inputs for prune.
|
// duplicated entries we might not have that much inputs for prune.
|
||||||
ASSERT_LE(sketch.Data().size(), n_cuts * kCols);
|
ASSERT_LE(sketch.Data().size(), n_cuts * kCols);
|
||||||
|
|
||||||
sketch.Prune(n_bins);
|
sketch.Prune(&ctx, n_bins);
|
||||||
ASSERT_LE(sketch.Data().size(), kRows * kCols);
|
ASSERT_LE(sketch.Data().size(), kRows * kCols);
|
||||||
// This is not necessarily true for all inputs without calling unique after
|
// This is not necessarily true for all inputs without calling unique after
|
||||||
// prune.
|
// prune.
|
||||||
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
|
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
|
||||||
sketch.Data().data() + sketch.Data().size(),
|
sketch.Data().data() + sketch.Data().size(),
|
||||||
detail::SketchUnique{}));
|
detail::SketchUnique{}));
|
||||||
TestQuantileElemRank(FstCU(), sketch.Data(), sketch.ColumnsPtr());
|
TestQuantileElemRank(ctx.Device(), sketch.Data(), sketch.ColumnsPtr());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GPUQuantile, MergeEmpty) {
|
TEST(GPUQuantile, MergeEmpty) {
|
||||||
constexpr size_t kRows = 1000, kCols = 100;
|
constexpr size_t kRows = 1000, kCols = 100;
|
||||||
size_t n_bins = 10;
|
size_t n_bins = 10;
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU());
|
SketchContainer sketch_0(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
HostDeviceVector<float> storage_0;
|
HostDeviceVector<float> storage_0;
|
||||||
std::string interface_str_0 =
|
std::string interface_str_0 =
|
||||||
RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).GenerateArrayInterface(
|
RandomDataGenerator{kRows, kCols, 0}.Device(ctx.Device()).GenerateArrayInterface(&storage_0);
|
||||||
&storage_0);
|
|
||||||
data::CupyAdapter adapter_0(interface_str_0);
|
data::CupyAdapter adapter_0(interface_str_0);
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
||||||
|
|
||||||
std::vector<SketchEntry> entries_before(sketch_0.Data().size());
|
std::vector<SketchEntry> entries_before(sketch_0.Data().size());
|
||||||
@ -170,7 +175,7 @@ TEST(GPUQuantile, MergeEmpty) {
|
|||||||
dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr());
|
dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr());
|
||||||
thrust::device_vector<size_t> columns_ptr(kCols + 1);
|
thrust::device_vector<size_t> columns_ptr(kCols + 1);
|
||||||
// Merge an empty sketch
|
// Merge an empty sketch
|
||||||
sketch_0.Merge(dh::ToSpan(columns_ptr), Span<SketchEntry>{});
|
sketch_0.Merge(&ctx, dh::ToSpan(columns_ptr), Span<SketchEntry>{});
|
||||||
|
|
||||||
std::vector<SketchEntry> entries_after(sketch_0.Data().size());
|
std::vector<SketchEntry> entries_after(sketch_0.Data().size());
|
||||||
dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data());
|
dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data());
|
||||||
@ -193,34 +198,36 @@ TEST(GPUQuantile, MergeEmpty) {
|
|||||||
TEST(GPUQuantile, MergeBasic) {
|
TEST(GPUQuantile, MergeBasic) {
|
||||||
constexpr size_t kRows = 1000, kCols = 100;
|
constexpr size_t kRows = 1000, kCols = 100;
|
||||||
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
|
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU());
|
SketchContainer sketch_0(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
HostDeviceVector<float> storage_0;
|
HostDeviceVector<float> storage_0;
|
||||||
std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}
|
std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}
|
||||||
.Device(FstCU())
|
.Device(ctx.Device())
|
||||||
.Seed(seed)
|
.Seed(seed)
|
||||||
.GenerateArrayInterface(&storage_0);
|
.GenerateArrayInterface(&storage_0);
|
||||||
data::CupyAdapter adapter_0(interface_str_0);
|
data::CupyAdapter adapter_0(interface_str_0);
|
||||||
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
||||||
|
|
||||||
SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, FstCU());
|
SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, ctx.Device());
|
||||||
HostDeviceVector<float> storage_1;
|
HostDeviceVector<float> storage_1;
|
||||||
std::string interface_str_1 =
|
std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0}
|
||||||
RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
|
.Device(ctx.Device())
|
||||||
&storage_1);
|
.Seed(seed)
|
||||||
|
.GenerateArrayInterface(&storage_1);
|
||||||
data::CupyAdapter adapter_1(interface_str_1);
|
data::CupyAdapter adapter_1(interface_str_1);
|
||||||
AdapterDeviceSketch(adapter_1.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
AdapterDeviceSketch(&ctx, adapter_1.Value(), n_bins, info,
|
||||||
&sketch_1);
|
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
|
||||||
|
|
||||||
size_t size_before_merge = sketch_0.Data().size();
|
size_t size_before_merge = sketch_0.Data().size();
|
||||||
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
|
sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
|
||||||
if (info.weights_.Size() != 0) {
|
if (info.weights_.Size() != 0) {
|
||||||
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr(), true);
|
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr(), true);
|
||||||
sketch_0.FixError();
|
sketch_0.FixError();
|
||||||
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr(), false);
|
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr(), false);
|
||||||
} else {
|
} else {
|
||||||
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr());
|
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto columns_ptr = sketch_0.ColumnsPtr();
|
auto columns_ptr = sketch_0.ColumnsPtr();
|
||||||
@ -228,7 +235,7 @@ TEST(GPUQuantile, MergeBasic) {
|
|||||||
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
|
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
|
||||||
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
|
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
|
||||||
|
|
||||||
sketch_0.Unique();
|
sketch_0.Unique(&ctx);
|
||||||
ASSERT_TRUE(
|
ASSERT_TRUE(
|
||||||
thrust::is_sorted(thrust::device, sketch_0.Data().data(),
|
thrust::is_sorted(thrust::device, sketch_0.Data().data(),
|
||||||
sketch_0.Data().data() + sketch_0.Data().size(),
|
sketch_0.Data().data() + sketch_0.Data().size(),
|
||||||
@ -237,25 +244,27 @@ TEST(GPUQuantile, MergeBasic) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
|
void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
int32_t seed = 0;
|
int32_t seed = 0;
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_0(ft, n_bins, cols, rows, FstCU());
|
SketchContainer sketch_0(ft, n_bins, cols, rows, ctx.Device());
|
||||||
HostDeviceVector<float> storage_0;
|
HostDeviceVector<float> storage_0;
|
||||||
std::string interface_str_0 =
|
std::string interface_str_0 = RandomDataGenerator{rows, cols, 0}
|
||||||
RandomDataGenerator{rows, cols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
|
.Device(ctx.Device())
|
||||||
&storage_0);
|
.Seed(seed)
|
||||||
|
.GenerateArrayInterface(&storage_0);
|
||||||
data::CupyAdapter adapter_0(interface_str_0);
|
data::CupyAdapter adapter_0(interface_str_0);
|
||||||
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(),
|
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
|
||||||
&sketch_0);
|
|
||||||
|
|
||||||
size_t f_rows = rows * frac;
|
size_t f_rows = rows * frac;
|
||||||
SketchContainer sketch_1(ft, n_bins, cols, f_rows, FstCU());
|
SketchContainer sketch_1(ft, n_bins, cols, f_rows, ctx.Device());
|
||||||
HostDeviceVector<float> storage_1;
|
HostDeviceVector<float> storage_1;
|
||||||
std::string interface_str_1 =
|
std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0}
|
||||||
RandomDataGenerator{f_rows, cols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
|
.Device(ctx.Device())
|
||||||
&storage_1);
|
.Seed(seed)
|
||||||
|
.GenerateArrayInterface(&storage_1);
|
||||||
auto data_1 = storage_1.DeviceSpan();
|
auto data_1 = storage_1.DeviceSpan();
|
||||||
auto tuple_it = thrust::make_tuple(
|
auto tuple_it = thrust::make_tuple(
|
||||||
thrust::make_counting_iterator<size_t>(0ul), data_1.data());
|
thrust::make_counting_iterator<size_t>(0ul), data_1.data());
|
||||||
@ -271,20 +280,19 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
data::CupyAdapter adapter_1(interface_str_1);
|
data::CupyAdapter adapter_1(interface_str_1);
|
||||||
AdapterDeviceSketch(adapter_1.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter_1.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(),
|
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
|
||||||
&sketch_1);
|
|
||||||
|
|
||||||
size_t size_before_merge = sketch_0.Data().size();
|
size_t size_before_merge = sketch_0.Data().size();
|
||||||
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
|
sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
|
||||||
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr());
|
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
|
||||||
|
|
||||||
auto columns_ptr = sketch_0.ColumnsPtr();
|
auto columns_ptr = sketch_0.ColumnsPtr();
|
||||||
std::vector<bst_idx_t> h_columns_ptr(columns_ptr.size());
|
std::vector<bst_idx_t> h_columns_ptr(columns_ptr.size());
|
||||||
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
|
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
|
||||||
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
|
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
|
||||||
|
|
||||||
sketch_0.Unique();
|
sketch_0.Unique(&ctx);
|
||||||
columns_ptr = sketch_0.ColumnsPtr();
|
columns_ptr = sketch_0.ColumnsPtr();
|
||||||
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
|
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
|
||||||
|
|
||||||
@ -311,7 +319,8 @@ TEST(GPUQuantile, MultiMerge) {
|
|||||||
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
|
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
|
||||||
// Set up single node version
|
// Set up single node version
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, FstCU());
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
|
|
||||||
size_t intermediate_num_cuts = std::min(
|
size_t intermediate_num_cuts = std::min(
|
||||||
kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
|
kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
|
||||||
@ -319,25 +328,26 @@ TEST(GPUQuantile, MultiMerge) {
|
|||||||
for (auto rank = 0; rank < world; ++rank) {
|
for (auto rank = 0; rank < world; ++rank) {
|
||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||||
.Device(FstCU())
|
.Device(ctx.Device())
|
||||||
.Seed(rank + seed)
|
.Seed(rank + seed)
|
||||||
.GenerateArrayInterface(&storage);
|
.GenerateArrayInterface(&storage);
|
||||||
data::CupyAdapter adapter(interface_str);
|
data::CupyAdapter adapter(interface_str);
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
containers.emplace_back(ft, n_bins, kCols, kRows, FstCU());
|
containers.emplace_back(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(),
|
std::numeric_limits<float>::quiet_NaN(), &containers.back());
|
||||||
&containers.back());
|
|
||||||
}
|
}
|
||||||
for (auto &sketch : containers) {
|
for (auto &sketch : containers) {
|
||||||
sketch.Prune(intermediate_num_cuts);
|
sketch.Prune(&ctx, intermediate_num_cuts);
|
||||||
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
|
sketch_on_single_node.Merge(&ctx, sketch.ColumnsPtr(), sketch.Data());
|
||||||
sketch_on_single_node.FixError();
|
sketch_on_single_node.FixError();
|
||||||
}
|
}
|
||||||
TestQuantileElemRank(FstCU(), sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr());
|
TestQuantileElemRank(ctx.Device(), sketch_on_single_node.Data(),
|
||||||
|
sketch_on_single_node.ColumnsPtr());
|
||||||
|
|
||||||
sketch_on_single_node.Unique();
|
sketch_on_single_node.Unique(&ctx);
|
||||||
TestQuantileElemRank(FstCU(), sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr());
|
TestQuantileElemRank(ctx.Device(), sketch_on_single_node.Data(),
|
||||||
|
sketch_on_single_node.ColumnsPtr());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,15 +402,15 @@ void TestAllReduceBasic() {
|
|||||||
data::CupyAdapter adapter(interface_str);
|
data::CupyAdapter adapter(interface_str);
|
||||||
HostDeviceVector<FeatureType> ft({}, device);
|
HostDeviceVector<FeatureType> ft({}, device);
|
||||||
containers.emplace_back(ft, n_bins, kCols, kRows, device);
|
containers.emplace_back(ft, n_bins, kCols, kRows, device);
|
||||||
AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
|
||||||
&containers.back());
|
std::numeric_limits<float>::quiet_NaN(), &containers.back());
|
||||||
}
|
}
|
||||||
for (auto& sketch : containers) {
|
for (auto& sketch : containers) {
|
||||||
sketch.Prune(intermediate_num_cuts);
|
sketch.Prune(&ctx, intermediate_num_cuts);
|
||||||
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
|
sketch_on_single_node.Merge(&ctx, sketch.ColumnsPtr(), sketch.Data());
|
||||||
sketch_on_single_node.FixError();
|
sketch_on_single_node.FixError();
|
||||||
}
|
}
|
||||||
sketch_on_single_node.Unique();
|
sketch_on_single_node.Unique(&ctx);
|
||||||
TestQuantileElemRank(device, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr(),
|
TestQuantileElemRank(device, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr(),
|
||||||
true);
|
true);
|
||||||
|
|
||||||
@ -416,16 +426,16 @@ void TestAllReduceBasic() {
|
|||||||
.Seed(rank + seed)
|
.Seed(rank + seed)
|
||||||
.GenerateArrayInterface(&storage);
|
.GenerateArrayInterface(&storage);
|
||||||
data::CupyAdapter adapter(interface_str);
|
data::CupyAdapter adapter(interface_str);
|
||||||
AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
|
||||||
&sketch_distributed);
|
std::numeric_limits<float>::quiet_NaN(), &sketch_distributed);
|
||||||
if (world == 1) {
|
if (world == 1) {
|
||||||
auto n_samples_global = kRows * world;
|
auto n_samples_global = kRows * world;
|
||||||
intermediate_num_cuts =
|
intermediate_num_cuts =
|
||||||
std::min(n_samples_global, static_cast<size_t>(n_bins * SketchContainer::kFactor));
|
std::min(n_samples_global, static_cast<size_t>(n_bins * SketchContainer::kFactor));
|
||||||
sketch_distributed.Prune(intermediate_num_cuts);
|
sketch_distributed.Prune(&ctx, intermediate_num_cuts);
|
||||||
}
|
}
|
||||||
sketch_distributed.AllReduce(&ctx, false);
|
sketch_distributed.AllReduce(&ctx, false);
|
||||||
sketch_distributed.Unique();
|
sketch_distributed.Unique(&ctx);
|
||||||
|
|
||||||
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size());
|
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size());
|
||||||
ASSERT_EQ(sketch_distributed.Data().size(), sketch_on_single_node.Data().size());
|
ASSERT_EQ(sketch_distributed.Data().size(), sketch_on_single_node.Data().size());
|
||||||
@ -535,11 +545,10 @@ void TestSameOnAllWorkers() {
|
|||||||
.Seed(rank + seed)
|
.Seed(rank + seed)
|
||||||
.GenerateArrayInterface(&storage);
|
.GenerateArrayInterface(&storage);
|
||||||
data::CupyAdapter adapter(interface_str);
|
data::CupyAdapter adapter(interface_str);
|
||||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
|
||||||
std::numeric_limits<float>::quiet_NaN(),
|
std::numeric_limits<float>::quiet_NaN(), &sketch_distributed);
|
||||||
&sketch_distributed);
|
|
||||||
sketch_distributed.AllReduce(&ctx, false);
|
sketch_distributed.AllReduce(&ctx, false);
|
||||||
sketch_distributed.Unique();
|
sketch_distributed.Unique(&ctx);
|
||||||
TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
|
TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
|
||||||
|
|
||||||
// Test for all workers having the same sketch.
|
// Test for all workers having the same sketch.
|
||||||
@ -547,16 +556,13 @@ void TestSameOnAllWorkers() {
|
|||||||
auto rc = collective::Allreduce(&ctx, linalg::MakeVec(&n_data, 1), collective::Op::kMax);
|
auto rc = collective::Allreduce(&ctx, linalg::MakeVec(&n_data, 1), collective::Op::kMax);
|
||||||
SafeColl(rc);
|
SafeColl(rc);
|
||||||
ASSERT_EQ(n_data, sketch_distributed.Data().size());
|
ASSERT_EQ(n_data, sketch_distributed.Data().size());
|
||||||
size_t size_as_float =
|
size_t size_as_float = sketch_distributed.Data().size_bytes() / sizeof(float);
|
||||||
sketch_distributed.Data().size_bytes() / sizeof(float);
|
|
||||||
auto local_data = Span<float const>{
|
auto local_data = Span<float const>{
|
||||||
reinterpret_cast<float const *>(sketch_distributed.Data().data()),
|
reinterpret_cast<float const*>(sketch_distributed.Data().data()), size_as_float};
|
||||||
size_as_float};
|
|
||||||
|
|
||||||
dh::caching_device_vector<float> all_workers(size_as_float * world);
|
dh::caching_device_vector<float> all_workers(size_as_float * world);
|
||||||
thrust::fill(all_workers.begin(), all_workers.end(), 0);
|
thrust::fill(all_workers.begin(), all_workers.end(), 0);
|
||||||
thrust::copy(thrust::device, local_data.data(),
|
thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(),
|
||||||
local_data.data() + local_data.size(),
|
|
||||||
all_workers.begin() + local_data.size() * rank);
|
all_workers.begin() + local_data.size() * rank);
|
||||||
rc = collective::Allreduce(
|
rc = collective::Allreduce(
|
||||||
&ctx, linalg::MakeVec(all_workers.data().get(), all_workers.size(), ctx.Device()),
|
&ctx, linalg::MakeVec(all_workers.data().get(), all_workers.size(), ctx.Device()),
|
||||||
@ -590,6 +596,7 @@ TEST_F(MGPUQuantileTest, SameOnAllWorkers) {
|
|||||||
TEST(GPUQuantile, Push) {
|
TEST(GPUQuantile, Push) {
|
||||||
size_t constexpr kRows = 100;
|
size_t constexpr kRows = 100;
|
||||||
std::vector<float> data(kRows);
|
std::vector<float> data(kRows);
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
|
||||||
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
|
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
|
||||||
std::fill(data.begin() + (data.size() / 2), data.end(), 0.5f);
|
std::fill(data.begin() + (data.size() / 2), data.end(), 0.5f);
|
||||||
@ -608,8 +615,8 @@ TEST(GPUQuantile, Push) {
|
|||||||
columns_ptr[1] = kRows;
|
columns_ptr[1] = kRows;
|
||||||
|
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
|
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
|
sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
|
||||||
|
|
||||||
auto sketch_data = sketch.Data();
|
auto sketch_data = sketch.Data();
|
||||||
|
|
||||||
@ -633,9 +640,9 @@ TEST(GPUQuantile, Push) {
|
|||||||
TEST(GPUQuantile, MultiColPush) {
|
TEST(GPUQuantile, MultiColPush) {
|
||||||
size_t constexpr kRows = 100, kCols = 4;
|
size_t constexpr kRows = 100, kCols = 4;
|
||||||
std::vector<float> data(kRows * kCols);
|
std::vector<float> data(kRows * kCols);
|
||||||
|
|
||||||
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
|
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
|
||||||
|
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
std::vector<Entry> entries(kRows * kCols);
|
std::vector<Entry> entries(kRows * kCols);
|
||||||
|
|
||||||
for (bst_feature_t c = 0; c < kCols; ++c) {
|
for (bst_feature_t c = 0; c < kCols; ++c) {
|
||||||
@ -648,7 +655,7 @@ TEST(GPUQuantile, MultiColPush) {
|
|||||||
|
|
||||||
int32_t n_bins = 16;
|
int32_t n_bins = 16;
|
||||||
HostDeviceVector<FeatureType> ft;
|
HostDeviceVector<FeatureType> ft;
|
||||||
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
|
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
|
||||||
dh::device_vector<Entry> d_entries {entries};
|
dh::device_vector<Entry> d_entries {entries};
|
||||||
|
|
||||||
dh::device_vector<size_t> columns_ptr(kCols + 1, 0);
|
dh::device_vector<size_t> columns_ptr(kCols + 1, 0);
|
||||||
@ -659,8 +666,8 @@ TEST(GPUQuantile, MultiColPush) {
|
|||||||
columns_ptr.begin());
|
columns_ptr.begin());
|
||||||
dh::device_vector<size_t> cuts_ptr(columns_ptr);
|
dh::device_vector<size_t> cuts_ptr(columns_ptr);
|
||||||
|
|
||||||
sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr),
|
sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr),
|
||||||
dh::ToSpan(cuts_ptr), kRows * kCols, {});
|
kRows * kCols, {});
|
||||||
|
|
||||||
auto sketch_data = sketch.Data();
|
auto sketch_data = sketch.Data();
|
||||||
ASSERT_EQ(sketch_data.size(), kCols * 2);
|
ASSERT_EQ(sketch_data.size(), kCols * 2);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user